In [None]:
import os
import glob
import feather
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
import anndata
import numpy as np
import copy
import openpyxl
import seaborn as sns

In [None]:
def add_ncounts_ngenes(ad):
    """Annotate ``ad.obs`` with per-observation (well) totals, in place.

    Adds two columns:

    * ``n_counts`` -- sum of ``ad.X`` across genes for each observation.
    * ``n_genes``  -- number of genes with a non-zero value in ``ad.X``.

    Works for dense ``numpy`` arrays and ``scipy.sparse`` matrices alike. Row sums of
    a sparse matrix are returned as a 2-D ``np.matrix``, so they are flattened to 1-D
    before being stored in ``obs``.

    Args:
        ad: AnnData-like object exposing ``X`` and ``obs``. Modified in place.
    """
    # add counts per well
    ad.obs['n_counts'] = np.asarray(ad.X.sum(axis=1)).ravel()

    # add genes per well (non-zero entries per row)
    ad.obs['n_genes'] = np.asarray(ad.X.astype(bool).sum(axis=1)).ravel()
    
def import_feather(path, seq_type, sample_batch, data_dir='FFPE/seq_data/'):
    """
    Load one zUMIs output directory (feather files + sample sheet) into an AnnData object.

    Reads ``zUMI_dataframe_umi.feather`` when it exists, otherwise
    ``zUMI_dataframe_counts.feather``. Gene rows are relabelled from Ensembl IDs to gene
    names, and barcode columns are relabelled to sample names from the sample sheet.

    Args:
        path: Name of the run directory inside ``data_dir``. It must contain
            ``ensemblID_to_gene.feather``, ``matrix_genes.feather``, ``sample_sheet.txt``
            and one of the zUMI count dataframes.
        seq_type: Value written to ``obs['seq_type']`` for every observation.
        sample_batch: Value written to ``obs['sample_batch']`` for every observation.
        data_dir: Prefix prepended to ``path`` by plain string concatenation, so keep
            the trailing slash. Defaults to the previously hard-coded ``'FFPE/seq_data/'``.

    Returns:
        AnnData object: observations = samples (wells), variables = gene names (made
        unique), ``X`` = UMI counts (or read counts if no UMI file exists). ``obs``
        holds the sample-sheet columns plus ``seq_type``, ``sample_batch``,
        ``n_counts`` and ``n_genes``.

    Raises:
        KeyError: if a gene ID in ``matrix_genes.feather`` has no gene-name mapping, or
            a count-matrix column has no matching row in the sample sheet.
    """
    run_dir = data_dir + path

    FILE_ensembl_to_gene = run_dir + "/ensemblID_to_gene.feather"
    FILE_matrix_genes = run_dir + "/matrix_genes.feather"
    FILE_counts = run_dir + "/zUMI_dataframe_counts.feather"
    FILE_umi = run_dir + "/zUMI_dataframe_umi.feather"
    FILE_sample_sheet = run_dir + "/sample_sheet.txt"

    # load sample sheet and normalise alternative column spellings
    sample_sheet = pd.read_csv(FILE_sample_sheet, sep="\t")
    if 'PCRcycles' in sample_sheet.columns:
        sample_sheet['PCR_cycles'] = sample_sheet['PCRcycles'].copy()
    if 'sample' not in sample_sheet.columns and 'Sample' in sample_sheet.columns:
        sample_sheet['sample'] = sample_sheet['Sample'].copy()

    # load ensembl ID to gene name conversion
    gene_id_to_name_DF = feather.read_dataframe(FILE_ensembl_to_gene)
    gene_name_dict = pd.Series(gene_id_to_name_DF.gene_name.values,
                               index=gene_id_to_name_DF.gene_id).to_dict()

    # Gene IDs are in the first column of matrix_genes. Map them to names in one
    # vectorised step. This replaces an iterrows loop that relied on positional
    # ``row[0]``, which is deprecated on label-indexed Series in pandas 2.
    matrix_genes_DF = feather.read_dataframe(FILE_matrix_genes)
    gene_ids = matrix_genes_DF.iloc[:, 0]
    unmapped = gene_ids[~gene_ids.isin(set(gene_name_dict))]
    if len(unmapped) > 0:
        # Same failure as the former dict lookup, but it names the offending IDs.
        raise KeyError('No gene name for %d gene ID(s), e.g. %s'
                       % (len(unmapped), list(unmapped[:5])))
    matrix_genes_DF['gene_name'] = gene_ids.map(gene_name_dict)

    # barcode -> sample (well) name
    bc_to_well = dict(zip(sample_sheet["concat_bc"], sample_sheet["sample"]))

    # load count dataframe (prefer UMIs, fall back to raw read counts)
    count_file = FILE_umi if os.path.isfile(FILE_umi) else FILE_counts
    DF = feather.read_dataframe(count_file)

    DF['gene_name'] = matrix_genes_DF['gene_name']  # add gene name column
    DF.set_index('gene_name', inplace=True)  # set gene name as index
    DF.rename(columns=bc_to_well, inplace=True)  # barcodes -> well positions
    umi_cts = DF.transpose()
    umi_cts.fillna(0, inplace=True)

    # build obs dataframe with sample attribute values, ordered like the count rows
    obs_df = sample_sheet.set_index('sample')
    obs_df = obs_df[[x for x in sample_sheet.columns if x not in ["sample", "concat_bc"]]]
    obs_df = obs_df.loc[list(umi_cts.index)].copy()
    obs_df["seq_type"] = seq_type
    obs_df["sample_batch"] = sample_batch

    # build var dataframe
    var_df = pd.DataFrame([], index=umi_cts.columns)

    # build anndata object
    cts_ad = anndata.AnnData(X=umi_cts.values, obs=obs_df, var=var_df)
    cts_ad.var_names_make_unique()

    add_ncounts_ngenes(cts_ad)

    print(cts_ad.obs.head(1))
    print(cts_ad.shape)

    return cts_ad

In [None]:
def split_hg19_and_mm10(ad):
    """Split a mixed-species (human hg19 + mouse mm10) AnnData into per-species objects.

    Gene names are expected to start with ``'hg19_'`` or ``'mm10_'``. Genes with neither
    prefix are present only in ``ad_all``.

    Returns:
        ad_hg19: copy restricted to ``hg19_`` genes, with the prefix stripped from
            ``var_names``. Adds ``obs['n_counts_hg19']`` and ``obs['n_genes_hg19']``.
        ad_mm10: same for ``mm10_`` genes. Also carries ``obs['percent_mm10']`` and
            ``obs['percent_hg19']``.
        ad_all: copy of ``ad`` with the four per-species columns and both percentages
            (share of each observation's total counts, in %).
    """
    def _row_sums(X):
        # Flatten so sparse matrices (which return np.matrix) also yield 1-D arrays.
        return np.asarray(X.sum(axis=1)).ravel()

    def _row_nnz(X):
        return np.asarray(X.astype(bool).sum(axis=1)).ravel()

    def _species_subset(prefix):
        genes = [g for g in ad.var_names if g.startswith(prefix)]
        # Explicit copy: renaming var / writing obs on an AnnData *view* makes anndata
        # silently materialise it (ImplicitModificationWarning).
        sub = ad[:, genes].copy()
        # Slice the prefix off. The former str.lstrip('hg19') strips a *character set*,
        # not a prefix, and only worked because '_' stopped it.
        sub.var_names = [g[len(prefix):] for g in genes]
        return genes, sub

    hg19_genes, ad_hg19 = _species_subset('hg19_')
    print(ad_hg19.shape)

    mm10_genes, ad_mm10 = _species_subset('mm10_')
    print(ad_mm10.shape)

    for sub, tag in ((ad_hg19, 'hg19'), (ad_mm10, 'mm10')):
        sub.obs['n_counts_%s' % tag] = _row_sums(sub.X)
        sub.obs['n_genes_%s' % tag] = _row_nnz(sub.X)

    ad_all = ad.copy()
    ad_all.obs['n_counts_hg19'] = ad_hg19.obs['n_counts_hg19']
    ad_all.obs['n_genes_hg19'] = ad_hg19.obs['n_genes_hg19']
    ad_all.obs['n_counts_mm10'] = ad_mm10.obs['n_counts_mm10']
    ad_all.obs['n_genes_mm10'] = ad_mm10.obs['n_genes_mm10']

    total = _row_sums(ad_all.X)
    ad_all.obs['percent_mm10'] = 100 * _row_sums(ad_all[:, mm10_genes].X) / total
    ad_all.obs['percent_hg19'] = 100 * _row_sums(ad_all[:, hg19_genes].X) / total

    ad_mm10.obs['percent_mm10'] = ad_all.obs['percent_mm10']
    ad_mm10.obs['percent_hg19'] = ad_all.obs['percent_hg19']

    print('Returned ad_hg19, ad_mm10, ad_all')
    return ad_hg19, ad_mm10, ad_all

In [None]:
def write_adata_to_csvs(ad, filename, writemode='default', out_dir='write/submission_files'):
    """Export an AnnData object as three CSV files (e.g. for data submission).

    Writes ``<out_dir>/<filename>_matrix.csv`` (genes as rows, observations as columns),
    ``<out_dir>/<filename>_var.csv`` and ``<out_dir>/<filename>_obs.csv``.

    Args:
        ad: AnnData-like object (``X``, ``layers``, ``obs``, ``var``, ``obs_names``,
            ``var_names``).
        filename: basename prefix of the three output files.
        writemode: ``'default'`` exports ``ad.X``; ``'counts'`` exports
            ``ad.layers['counts']``.
        out_dir: destination directory, which must already exist. Defaults to the
            previously hard-coded ``'write/submission_files'``.

    Raises:
        ValueError: if ``writemode`` is not supported. Previously this surfaced as an
            ``UnboundLocalError``.
    """
    if writemode == 'default':
        matrix = ad.X
    elif writemode == 'counts':
        matrix = ad.layers['counts']
    else:
        raise ValueError("writemode must be 'default' or 'counts', got %r" % (writemode,))

    # pd.DataFrame cannot take a scipy.sparse matrix; densify only when needed, so a
    # dense 'counts' layer (which has no .toarray()) also works.
    if hasattr(matrix, 'toarray'):
        matrix = matrix.toarray()

    # gene expression matrix
    df = pd.DataFrame(data=matrix, index=ad.obs_names, columns=ad.var_names)
    df.T.to_csv('%s/%s_matrix.csv' % (out_dir, filename))

    # gene names
    ad.var.to_csv('%s/%s_var.csv' % (out_dir, filename))

    # barcodes / nuclei and metadata
    ad.obs.to_csv('%s/%s_obs.csv' % (out_dir, filename))

In [None]:
def calc_mito_ncounts(ad, case='mouse'):
    """Add mitochondrial fraction and total counts per observation to ``ad.obs``.

    Args:
        ad: AnnData-like object with ``X``, ``var_names`` and ``obs``. Modified in place.
        case: selects the mitochondrial gene-name prefix:
            ``'human'`` -> ``'MT-'``, ``'hg19'`` -> ``'hg19_MT-'``, ``'mouse'`` -> ``'mt-'``.

    Returns:
        The same ``ad``, with ``obs['frac_mito']`` (mito counts / total counts) and
        ``obs['n_counts']`` set.

    Raises:
        ValueError: for an unknown ``case``. Previously this surfaced as an
            ``UnboundLocalError``.
    """
    mito_prefixes = {
        'human': 'MT-',       # human
        'hg19': 'hg19_MT-',   # human genes in a mixed hg19/mm10 reference
        'mouse': 'mt-',       # mouse
    }
    if case not in mito_prefixes:
        raise ValueError('case must be one of %s, got %r' % (sorted(mito_prefixes), case))

    mito_mask = np.asarray(ad.var_names.str.startswith(mito_prefixes[case]))

    # Row sums flattened to 1-D so sparse X (np.matrix results) works too.
    total = np.asarray(ad.X.sum(axis=1)).ravel()
    mito = np.asarray(ad.X[:, mito_mask].sum(axis=1)).ravel()

    # calculate mitochondrial fraction
    ad.obs['frac_mito'] = mito / total

    # add the total counts per cell as observations-annotation to adata
    ad.obs['n_counts'] = total
    return ad

In [None]:
def construct_batch_logical_gate(ad, blims, feat='n_genes'):
    """Boolean mask keeping observations whose ``feat`` lies strictly within per-batch limits.

    Args:
        ad: AnnData-like object. ``obs`` must have a categorical ``'batch'`` column and
            a ``feat`` column.
        blims: mapping ``batch -> (low, high)``. An entry is needed for every category
            of ``obs['batch']``.
        feat: ``obs`` column to threshold. Both bounds are exclusive.

    Returns:
        Boolean ``pd.Series`` aligned with ``ad.obs``. True where the observation's value
        lies inside the limits of its own batch.
    """
    obs = ad.obs
    keep = pd.Series(False, index=obs.index)
    for batch_name in obs['batch'].cat.categories:
        low = blims[batch_name][0]
        high = blims[batch_name][1]
        in_batch = obs['batch'] == batch_name
        in_window = (obs[feat] > low) & (obs[feat] < high)
        keep = keep | (in_batch & in_window)
    return keep

In [None]:
def get_cells_per_gene(ad):
    """Count, for every gene (column of ``ad.X``), the observations with a non-zero value.

    Works for dense arrays and ``scipy.sparse`` matrices without densifying the whole
    matrix. The former ``ad.X.toarray()`` raised on dense ``X`` and materialised the
    full matrix for sparse ``X``.

    Returns:
        1-D integer ``np.ndarray`` with one entry per column of ``ad.X``.
    """
    nonzero = ad.X.astype(bool)
    # Sparse column sums come back as a (1, n) np.matrix; flatten to 1-D.
    return np.asarray(nonzero.sum(axis=0)).ravel()

In [None]:
def uniform_jitter(obj_len, minval=-2, maxval=2):
    """Draw ``obj_len`` samples uniformly from ``[minval, maxval)`` with NumPy's global RNG.

    Used to spread overlapping points in scatter plots. Seed ``np.random`` beforehand
    for reproducible jitter.
    """
    return np.random.uniform(low=minval, high=maxval, size=obj_len)

def scatter_hg19_vs_mm10(ad, hue_var='plate', savefig=False):
    """Species-mixing scatter plot: % hg19 counts (x) vs % mm10 counts (y) per observation.

    Uniform jitter in [-2, 2) is added to both coordinates so overlapping points stay
    visible.

    Args:
        ad: AnnData whose ``obs`` has ``percent_hg19``, ``percent_mm10`` and ``hue_var``.
        hue_var: ``obs`` column used to colour the points.
        savefig: if True, save ``scatter_mm10_hg19.pdf`` into ``sc.settings.figdir``.
    """
    obs = ad.obs
    n_obs = obs.shape[0]
    # x jitter is drawn before y jitter (same RNG draw order as before)
    x_vals = obs['percent_hg19'] + uniform_jitter(n_obs)
    y_vals = obs['percent_mm10'] + uniform_jitter(n_obs)

    fig, ax = plt.subplots(figsize=(5, 5))
    sns.scatterplot(x=x_vals, y=y_vals, data=obs, hue=obs[hue_var], ax=ax,
                    s=12, edgecolor='k', linewidth=0.3, alpha=0.5)
    ax.legend(loc='upper right', ncol=1, bbox_to_anchor=(1.5, 0.5))

    ax.set_xlim([-10, 110])
    ax.set_ylim([-10, 110])
    ax.set_xlabel('% hg19 gene counts', fontsize=16)
    ax.set_ylabel('% mm10 gene counts', fontsize=16)
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)

    if savefig:
        fig.savefig('%s/scatter_mm10_hg19.pdf' % (sc.settings.figdir), bbox_inches='tight')


In [None]:
def scatter_UMI_genes_hist(ad, samplestr, density=True, savefig=True): 
    """Scatter log10(n_genes) against log10(n_counts), with marginal histograms.

    Lays out three manually positioned axes on a 3x3-inch, 1200-dpi figure: the scatter
    plus a histogram above it and one to its right. Draws a dashed y = x reference line.
    Optionally colours points by a Gaussian-KDE estimate of point density.

    Args:
        ad: AnnData-like object whose ``obs`` has ``n_counts`` and ``n_genes``.
        samplestr: suffix used in the saved file name.
        density: if True, colour points by ``scipy.stats.gaussian_kde`` density and add a
            colorbar. Otherwise plot uniform gray points.
        savefig: if True, save ``scatter_ngenes_UMIs_hist_<samplestr>.png`` into
            ``sc.settings.figdir``.

    NOTE(review): observations with 0 counts/genes give log10 = -inf, which
    ``gaussian_kde`` cannot handle. Presumably such wells are filtered upstream; confirm.
    NOTE(review): ``plt.colorbar`` is called without ``ax=``/``cax=``, so Matplotlib
    picks the axes it takes space from. The scatter may no longer line up with the top
    histogram; verify the layout.
    """
    from scipy.stats import gaussian_kde

    # definitions for the axes ([left, bottom, width, height] in figure fractions)
    left, width = 0.1, 0.65
    bottom, height = 0.1, 0.65
    spacing = 0.005

    rect_scatter = [left, bottom, width, height]
    rect_histx = [left, bottom + height + spacing, width, 0.2]  # strip above the scatter
    rect_histy = [left + width + spacing, bottom, 0.2, height]  # strip right of the scatter

    h = plt.figure(figsize=(3,3), dpi=1200)  # `h` is not used afterwards

    ax_scatter = plt.axes(rect_scatter)
    ax_scatter.tick_params(direction='in', top=True, right=True)
    ax_histx = plt.axes(rect_histx)
    ax_histx.tick_params(direction='in', labelbottom=False)
    ax_histy = plt.axes(rect_histy)
    ax_histy.tick_params(direction='in', labelleft=False)

    # plot x=y line (n_genes == n_counts reference)
    ax_scatter.plot([0,6],[0,6],'k--',linewidth=1)
    
    x = np.log10(ad.obs['n_counts'])
    y = np.log10(ad.obs['n_genes'])
    
    # calculate density
    if density: 
        xy = np.vstack([x,y])
        z = gaussian_kde(xy)(xy)  # KDE evaluated at each point -> per-point colour value
        ax_sc = ax_scatter.scatter(x, y, c=z, s=3, alpha=0.5, cmap='coolwarm')
        plt.colorbar(ax_sc)
    
    else: 

        ax_scatter.scatter(x, y, s=3, color='#696969', alpha=0.5)
    
    # shared histogram bins in log10 space: 0, 0.05, ..., 5.95
    bins = np.arange(0,6,0.05)
    ax_histx.hist(x, bins=bins, facecolor='#696969')
    ax_histy.hist(y, bins=bins, orientation='horizontal', facecolor='#696969')

    # set axis properties 
    ax_scatter.set_xlim([1,5])
    ax_scatter.set_xticks([1,2,3,4,5])
    ax_scatter.set_ylim([1,4])
    ax_scatter.set_yticks([1,2,3,4])
    ax_scatter.set_xlabel('log10(n_UMIs)', fontsize=12)
    ax_scatter.set_ylabel('log10(n_genes)', fontsize=12)
    
    # keep marginal histograms aligned with the scatter's visible range
    ax_histx.set_xlim(ax_scatter.get_xlim())
    ax_histy.set_ylim(ax_scatter.get_ylim())
    
    # strip frames/ticks from the marginal histograms
    ax_histx.spines['top'].set_visible(False)
    ax_histx.spines['right'].set_visible(False)
    ax_histx.spines['left'].set_visible(False)
    ax_histx.set_xticks([])
    ax_histx.set_yticks([])
    
    ax_histy.spines['top'].set_visible(False)
    ax_histy.spines['right'].set_visible(False)
    ax_histy.spines['bottom'].set_visible(False)
    ax_histy.set_xticks([])
    ax_histy.set_yticks([])

    if savefig: 
        plt.savefig('%s/scatter_ngenes_UMIs_hist_%s.png' %(sc.settings.figdir, samplestr), bbox_inches='tight')

In [None]:
def make_gray_monoscale_cmap():
    """Build a 16-colour ``ListedColormap``: light gray followed by 15 shades of 'Blues'.

    The first entry is RGBA (0.85, 0.85, 0.85, 1.0), so the lowest bin renders gray
    instead of blue. The other 15 colours sample a 200-level 'Blues' colormap evenly
    from 0 to 1.

    ``matplotlib.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9. The
    colormap registry is used instead, falling back to ``cm.get_cmap`` on releases that
    lack ``Colormap.resampled``.

    Returns:
        ``matplotlib.colors.ListedColormap`` named ``'blues_with_gray'``.
    """
    import matplotlib
    from matplotlib.colors import ListedColormap

    try:
        # Matplotlib >= 3.6
        blues = matplotlib.colormaps['Blues'].resampled(200)
    except AttributeError:
        # Older Matplotlib without the registry / Colormap.resampled
        from matplotlib import cm
        blues = cm.get_cmap('Blues', 200)

    blues_array = blues(np.linspace(0, 1, 15)).tolist()
    blues_array.insert(0, [0.85, 0.85, 0.85, 1.0])
    return ListedColormap(blues_array, name='blues_with_gray')

In [None]:
def load_pyannotable(path='pyannotables.pickle'):
    """Load the pickled pyannotables gene-annotation tables.

    Args:
        path: pickle file to read. Defaults to ``'pyannotables.pickle'`` in the working
            directory, the previously hard-coded location.

    Returns:
        The unpickled object. ``load_gene_len_dict`` indexes it by table name, e.g.
        ``'mus_musculus-GRCm38-ensembl100'``.
    """
    try:
        # pickle5 backports pickle protocol 5 to Python < 3.8
        import pickle5 as pickle
    except ImportError:
        # Python >= 3.8: the stdlib pickle reads protocol 5 natively
        import pickle

    # NOTE: unpickling can execute arbitrary code -- only load trusted local files.
    with open(path, 'rb') as handle:
        return pickle.load(handle)

def load_gene_len_dict(species):
    """Return a per-gene coding-length table for mouse or human, cached as a pickle.

    On the first call, lengths are derived from the pyannotables table
    (``mus_musculus-GRCm38-ensembl100`` or ``homo_sapiens-GRCh38-ensembl100``) and
    saved to ``<species>_gene_lengths.pickle`` in the working directory. Later calls
    read that cache.

    When a gene name has several annotation rows, the largest ``gene_coding_length``
    is kept. Rows with a missing gene name are ignored (they used to crash the
    per-gene loop).

    Args:
        species: ``'mouse'`` or ``'human'``.

    Returns:
        DataFrame indexed by ``'Gene'`` with a single ``'Length'`` column.

    Raises:
        ValueError: for any other ``species``. Previously this surfaced as an
            ``UnboundLocalError``.
    """
    try:
        import pickle5 as pickle  # backport only needed on Python < 3.8
    except ImportError:
        import pickle

    species_config = {
        'mouse': ('mus_musculus-GRCm38-ensembl100', 'mouse_gene_lengths.pickle'),
        'human': ('homo_sapiens-GRCh38-ensembl100', 'human_gene_lengths.pickle'),
    }
    if species not in species_config:
        raise ValueError("species must be 'mouse' or 'human', got %r" % (species,))
    table_type, dict_name = species_config[species]

    if os.path.isfile(dict_name):
        # NOTE: pickle.load can execute arbitrary code -- only use trusted cache files.
        with open(dict_name, 'rb') as handle:
            gene_len_dict = pickle.load(handle)
    else:
        table = load_pyannotable()[table_type]
        # A single groupby replaces re-filtering the whole table once per gene
        # (O(n_genes * n_rows)). Max over duplicate rows matches the old per-gene
        # logic; sort=False keeps first-appearance gene order.
        gene_len_dict = (table.groupby('gene_name', sort=False)['gene_coding_length']
                         .max()
                         .to_dict())

        # save
        with open(dict_name, 'wb') as handle:
            pickle.dump(gene_len_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

    genes_df = pd.DataFrame(list(gene_len_dict.items()), columns=['Gene', 'Length'])
    genes_df.set_index('Gene', inplace=True)

    return genes_df

In [None]:
def get_pathway_genes(pathway_type=('c2',)):
    """Read gene-set definitions from ``<type>_pathways.txt`` files in the working directory.

    Each non-blank line has the form ``pathway_name,gene1,gene2,...``.

    Args:
        pathway_type: collection name(s) to load, e.g. ``['c2']``. A single string is
            also accepted; previously ``'c2'`` was iterated character by character. The
            default is an immutable tuple, avoiding the shared mutable-default pitfall.

    Returns:
        dict mapping pathway name -> list of gene names. A name that appears on several
        lines or in several files keeps the last definition read.
    """
    if isinstance(pathway_type, str):
        pathway_type = [pathway_type]

    pathway_genes = {}
    n_read = 0
    for path in pathway_type:
        # `with` closes the file even if reading fails part-way.
        with open('%s_pathways.txt' % path) as fh:
            for line in fh:
                fields = line.strip().split(',')
                if fields == ['']:
                    continue  # skip blank lines instead of creating an '' pathway
                pathway_genes[fields[0]] = fields[1:]
                n_read += 1

    print('Loaded %i pathways' % (n_read))
    return pathway_genes

In [None]:
def program_counts(ad, genes):
    """Mean raw count per gene of a gene program, for every observation.

    Sums ``layers['counts']`` over ``genes`` and divides by ``len(genes)``.

    Args:
        ad: AnnData with a ``'counts'`` layer, dense or ``scipy.sparse``. The former
            ``.toarray()`` call failed on a dense layer.
        genes: var names to include, all present in ``ad``.

    Returns:
        1-D float ``np.ndarray`` with one value per observation.
    """
    counts = ad[:, genes].layers['counts']
    # Row sums without densifying; np.asarray(...).ravel() flattens sparse np.matrix output.
    vals = np.asarray(counts.sum(axis=1)).ravel()
    return vals / len(genes)

In [None]:
def weighted_program_counts(ad, genes, gene_weights):
    """Per-observation sum of raw counts over ``genes``, each divided by its weight.

    ``score_i = sum_j counts[i, j] / gene_weights[j]``

    Args:
        ad: AnnData with a ``'counts'`` layer, dense or ``scipy.sparse``. The former
            ``.toarray()`` call failed on a dense layer.
        genes: var names to include.
        gene_weights: one weight per gene, aligned with ``genes``.

    Returns:
        list of floats, one per observation.
    """
    counts = ad[:, genes].layers['counts']
    if hasattr(counts, 'toarray'):
        counts = counts.toarray()
    weights = np.asarray(gene_weights, dtype=float)
    # One broadcasted division instead of a Python loop over rows.
    return list((np.asarray(counts) / weights).sum(axis=1))

In [None]:
def get_gene_pathway_ct(pathway_dict):
    """Count how many pathway memberships each gene has.

    A gene listed twice in the same pathway is counted twice.

    Args:
        pathway_dict: mapping pathway name -> iterable of gene names.

    Returns:
        dict gene -> membership count, with keys in sorted order.
    """
    membership = {}
    for pathway_name in sorted(pathway_dict):
        for gene in pathway_dict[pathway_name]:
            membership[gene] = membership.get(gene, 0) + 1
    return {gene: membership[gene] for gene in sorted(membership)}

In [None]:
def weighted_pathway_aggregation(ad, pathways, pathway_dict, gene_pathway_ct, obs_keys=None):
    """Build an observations x pathways AnnData of weighted pathway scores.

    Each gene's raw count is divided by the number of pathways containing it
    (``gene_pathway_ct``) before summing, so genes shared by many pathways contribute
    less.

    Args:
        ad: AnnData with a ``'counts'`` layer and the ``obs`` columns in ``obs_keys``.
        pathways: ordered pathway names; they become the columns (var) of the result.
        pathway_dict: pathway name -> list of genes present in ``ad``.
        gene_pathway_ct: gene -> number of pathways containing it.
        obs_keys: ``obs`` columns to carry over. Defaults to the previously hard-coded
            ``['Sample', 'sample_batch', 'n_counts_hg19', 'n_genes_hg19', 'leiden', 'plate']``.

    Returns:
        AnnData with ``X`` of shape ``(ad.shape[0], len(pathways))`` and ``var`` indexed
        by ``'pathway'``.
    """
    if obs_keys is None:
        obs_keys = ['Sample', 'sample_batch', 'n_counts_hg19', 'n_genes_hg19', 'leiden', 'plate']

    pathway_scores = np.empty((ad.shape[0], len(pathways)))
    for idx, pway in enumerate(pathways):
        p_genes = pathway_dict[pway]
        gene_weights = [gene_pathway_ct[g] for g in p_genes]
        pathway_scores[:, idx] = weighted_program_counts(ad, p_genes, gene_weights)

    var_df = pd.DataFrame(index=pd.Index(pathways, name='pathway'))
    ad_path = anndata.AnnData(X=pathway_scores,
                              obs=ad.obs[list(obs_keys)],
                              var=var_df)

    return ad_path

In [None]:
def get_pathways_anndata(ad, pathway_genes, MIN_FRAC=0.75, MIN_GENES=10, MAX_GENES=200, weighted=False):
    """Score gene-set pathways per observation and return them as a new AnnData.

    A pathway is kept when both conditions hold:

    * the number of its genes present in ``ad.var_names`` is strictly between
      ``MIN_GENES`` and ``MAX_GENES``;
    * that number exceeds ``MIN_FRAC`` of the genes listed for the pathway.

    Pathway names are whitespace-stripped. After stripping, only the first occurrence of
    a name (in sorted key order) is kept.

    Args:
        ad: AnnData with a ``'counts'`` layer and the ``obs`` columns copied below.
            ``ad`` itself is not modified.
        pathway_genes: dict pathway name -> list of genes (e.g. from ``get_pathway_genes``).
        MIN_FRAC, MIN_GENES, MAX_GENES: filtering thresholds described above.
        weighted: if True, build the result with ``weighted_pathway_aggregation`` (genes
            down-weighted by how many kept pathways contain them). Otherwise each score
            is the mean raw count over the pathway's valid genes (``program_counts``).

    Returns:
        (ad_path, updated_dict, valid_pathways): the pathway AnnData, the dict of kept
        pathway -> genes actually used, and the ordered list of kept pathway names.
    """
    valid_pathways = []
    seen = set()
    updated_dict = {}

    print('Keeping only valid genes and pathways')
    for raw_name in sorted(pathway_genes.keys()):
        path = raw_name.strip()
        # Look up with the ORIGINAL key. The stripped name may not be a key, and using
        # it raised KeyError for names with surrounding whitespace.
        listed_genes = pathway_genes[raw_name]
        valid_genes = [g for g in listed_genes if g in ad.var_names]

        n_genes = len(valid_genes)
        if not (MIN_GENES < n_genes < MAX_GENES):
            continue
        if float(n_genes) / len(listed_genes) <= MIN_FRAC:
            continue
        if path in seen:  # remove duplicate pathway names
            continue

        seen.add(path)
        valid_pathways.append(path)
        updated_dict[path] = valid_genes

    print(len(valid_pathways))

    # make new anndata object
    if weighted:
        print('Making *weighted* anndata object')
        gene_pathway_ct = get_gene_pathway_ct(updated_dict)
        ad_path = weighted_pathway_aggregation(ad,
                                               valid_pathways,
                                               updated_dict,
                                               gene_pathway_ct)
    else:
        print('Making unweighted anndata object')
        # Fill the score matrix directly. Adding one obs column per pathway to a full
        # copy of `ad` fragmented the DataFrame and doubled memory.
        pathway_scores = np.empty((ad.shape[0], len(valid_pathways)))
        for idx, path in enumerate(valid_pathways):
            pathway_scores[:, idx] = program_counts(ad, updated_dict[path])

        var_df = pd.DataFrame(index=pd.Index(valid_pathways, name='pathway'))
        obs_keys = ['Sample', 'sample_batch',
                    'batch', 'plate',
                    'n_counts', 'n_genes',
                    'n_counts_hg19', 'n_genes_hg19']
        ad_path = anndata.AnnData(X=pathway_scores,
                                  obs=ad.obs[obs_keys].copy(),
                                  var=var_df)

    return ad_path, updated_dict, valid_pathways

In [None]:
def plot_genes_heatmap(ad, glist, savefig=False, savestr='', max_scale=5, pd_colors=None):
    """Clustered heatmap (genes x observations) of scaled, log-normalised expression.

    Genes missing from ``ad.var_names`` are silently dropped. A copy of the gene subset
    is normalised to 1e4 counts per observation (``sc.pp.normalize_per_cell``),
    log1p-transformed and scaled per gene (``sc.pp.scale``). NOTE(review):
    ``normalize_per_cell`` filters out observations below its ``min_counts`` default
    (1 count over these genes); confirm that dropping them is intended.

    Args:
        ad: AnnData to plot; not modified.
        glist: marker genes.
        savefig: if True, save ``clustermap_<savestr>.pdf`` into ``sc.settings.figdir``.
        savestr: file-name suffix.
        max_scale: colour scale is clipped to ``[-max_scale, max_scale]``.
        pd_colors: passed to ``sns.clustermap`` as ``col_colors``.

    Side effect: ``sns.reset_orig()`` resets the global seaborn/matplotlib style.
    """
    sns.reset_orig()
    glist = [g for g in glist if g in ad.var_names]
    ad1 = ad[:, glist].copy()
    sc.pp.normalize_per_cell(ad1, counts_per_cell_after=1e4)
    sc.pp.log1p(ad1)
    sc.pp.scale(ad1)

    cts = pd.DataFrame(ad1.X, index=ad1.obs.index)
    # sns.clustermap creates (and makes current) its own figure. The former bare
    # plt.figure() here only left an unused empty figure behind.
    cg = sns.clustermap(cts.T,
                        xticklabels=[],
                        yticklabels=ad1.var_names,
                        col_colors=pd_colors,
                        vmin=-max_scale, vmax=max_scale,
                        cmap='RdBu_r')
    cg.ax_row_dendrogram.set_visible(False)
    cg.ax_col_dendrogram.set_visible(False)
    cg.ax_heatmap.set_xlabel('Nuclei')
    cg.ax_heatmap.set_ylabel('Marker genes')

    if savefig:
        plt.savefig('%s/clustermap_%s.pdf' % (sc.settings.figdir, savestr))
    plt.show()

In [None]:
def get_cluster_proportions(adata,
                            cluster_key="leiden",
                            sample_key="sample_type",
                            drop_values=None):
    """
    Percentage of each sample's observations that fall in each cluster.

    Input
    =====
    adata : AnnData object (only ``adata.obs`` is read; nothing is copied or modified)
    cluster_key : key of `adata.obs` storing cluster info
    sample_key : key of `adata.obs` storing sample/replicate info
    drop_values : list/iterable of possible values of `sample_key` that you don't want

    Returns
    =======
    pd.DataFrame with samples as the index and clusters as the columns and 0-100 floats
    as values (each non-empty row sums to 100; absent combinations are 0).

    Raises
    ======
    KeyError if a value in `drop_values` is not a sample label.
    """
    obs = adata.obs
    # observed=False keeps pandas' historical default for categorical keys (every
    # category combination, zeros included) and silences the pandas >= 2.1 FutureWarning.
    sizes = obs.groupby([cluster_key, sample_key], observed=False).size()

    # transform() keeps the (cluster, sample) index. The former groupby(...).apply()
    # gets the group key prepended in pandas >= 2.0 (group_keys=True), which made the
    # following reset_index() fail with "cannot insert <sample_key>, already exists".
    totals = sizes.groupby(level=sample_key, observed=False).transform('sum')
    props = (100 * sizes / totals).unstack(cluster_key)
    props = props.fillna(0)  # 0/0 for empty samples and missing combinations

    if drop_values is not None:
        props = props.drop(index=list(drop_values))
    return props


def plot_cluster_proportions(cluster_props,
                             cluster_palette=None,
                             xlabel_rotation=0):
    """Grouped bar chart of a proportions table (e.g. from ``get_cluster_proportions``).

    Rows of ``cluster_props`` are on the x axis; each column is a separately coloured
    bar series shown in the legend (titled "Assay type").

    Args:
        cluster_props: DataFrame of percentages (0-100).
        cluster_palette: optional list of colours, blended into a colormap with one
            colour per entry.
        xlabel_rotation: passed to ``ax.tick_params``. NOTE(review): the later
            ``set_xticklabels(..., rotation=90)`` overrides it, so tick labels are always
            vertical. Kept as-is to preserve existing figures.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    fig.patch.set_facecolor("white")

    cmap = None
    if cluster_palette is not None:
        cmap = sns.palettes.blend_palette(
            cluster_palette,
            n_colors=len(cluster_palette),
            as_cmap=True)

    cluster_props.plot(
        kind="bar",
        ax=ax,
        legend=None,
        colormap=cmap
    )

    fontsize = 16
    ax.legend(bbox_to_anchor=(1.1, 1.2), frameon=False, title="Assay type", fontsize=fontsize)
    sns.despine(fig, ax)
    ax.tick_params(axis="x", rotation=xlabel_rotation)
    # An unnamed index (name None) used to crash on .capitalize(); use an empty label.
    index_name = cluster_props.index.name
    xlabel = '' if index_name is None else str(index_name).capitalize()
    ax.set_xlabel(xlabel, fontsize=fontsize)
    ax.set_ylabel("% of nuclei in cluster")
    ax.set_xticklabels(cluster_props.index, rotation=90, fontsize=fontsize)
    ax.set_yticks([0, 50, 100])
    ax.set_yticklabels([0, 50, 100], fontsize=fontsize)

    fig.tight_layout()

    return fig

In [None]:
def jaccard(list1, list2):
    """Jaccard similarity (intersection size / union size) of two collections as sets.

    Duplicates are ignored. Previously the union used raw list lengths, so repeated
    items lowered the score. Returns 0.0 when both inputs are empty instead of raising
    ZeroDivisionError.
    """
    set1, set2 = set(list1), set(list2)
    union = set1 | set2
    if not union:
        return 0.0
    return float(len(set1 & set2)) / len(union)

In [None]:
def calculate_jaccard_pathways(pathways, pathway_dict):
    """Pairwise Jaccard similarity between the gene sets of ``pathways``.

    Args:
        pathways: indexable sequence of pathway names.
        pathway_dict: pathway name -> list of genes.

    Returns:
        path_ij: symmetric DataFrame (index and columns = ``pathways``). The diagonal is
            left at 0.
        values: flat list of upper-triangle similarities (i < j) in row-major order.
    """
    n_paths = len(pathways)
    sim = np.zeros((n_paths, n_paths))
    upper_values = []

    for i in range(n_paths):
        genes_i = pathway_dict[pathways[i]]
        for j in range(i + 1, n_paths):
            score = jaccard(genes_i, pathway_dict[pathways[j]])
            sim[i, j] = score
            sim[j, i] = score
            upper_values.append(score)

    return pd.DataFrame(sim, columns=pathways, index=pathways), upper_values

In [None]:
def cluster_path_Jij(df_J): 
    """Hierarchically cluster a pathway-by-pathway similarity matrix and save the heatmap.

    Draws ``sns.clustermap`` with Ward linkage and a fixed 0-1 'Blues' colour scale,
    hides the column dendrogram, and saves ``pathway_dendrogram_ward.pdf`` into
    ``sc.settings.figdir``.

    Args:
        df_J: square DataFrame of pairwise similarities (e.g. the first output of
            ``calculate_jaccard_pathways``). Row labels are drawn as y tick labels.

    Returns:
        The seaborn ClusterGrid; ``extract_cluster_feats`` reads its ordering/linkage.

    Side effects: ``sns.set(font_scale=0.2)`` changes the global seaborn/matplotlib
    style for all later plots, and ``random.seed(0)`` reseeds Python's global RNG.
    """
    
    import random
    random.seed(0)  # NOTE(review): nothing below uses `random`; this only reseeds the global RNG
    
    # NOTE(review): sns.clustermap creates its own figure, so this 20x20 figure is
    # presumably left empty and the clustermap keeps its default size -- confirm intent.
    plt.figure(figsize=(20,20))
    sns.set(font_scale=0.2)  # presumably tiny labels so many pathway names fit
    cg = sns.clustermap(df_J,
                        yticklabels=True,
                        xticklabels=False,
                        vmin=0,
                        vmax=1,
                        method='ward',
                        cmap='Blues')
    cg.ax_col_dendrogram.set_visible(False)
    # plt.savefig targets the current figure, i.e. the clustermap's figure
    plt.savefig('%s/pathway_dendrogram_ward.pdf' %(sc.settings.figdir),
                   bbox_inches='tight')
    
    return cg

def extract_cluster_feats(cg):
    """Return the column leaf order and linkage matrix of a seaborn ClusterGrid.

    Args:
        cg: object returned by ``sns.clustermap`` (e.g. from ``cluster_path_Jij``).

    Returns:
        (ordered_idx, Z): ``cg.dendrogram_col.reordered_ind`` and
        ``cg.dendrogram_col.linkage``.
    """
    import random
    random.seed(0)  # kept for identical side effects: reseeds Python's global RNG

    col_dendrogram = cg.dendrogram_col
    return col_dendrogram.reordered_ind, col_dendrogram.linkage

In [None]:
def cut_tree(Z, pathways, CUT_THRESHOLD=2.3):
    """Cut a pathway dendrogram at a fixed distance and return flat cluster labels.

    Also plots the dendrogram (leaves unlabeled, links above the threshold in gray) with
    a dashed line at the cut. The plot is saved as
    ``pathway_dendrogram_tree_cut_threshold.pdf`` in ``sc.settings.figdir``.

    Args:
        Z: scipy linkage matrix (e.g. from ``extract_cluster_feats``).
        pathways: pathway names in the same order as the observations clustered in ``Z``.
        CUT_THRESHOLD: distance used both for ``fcluster(..., criterion='distance')`` and
            as the dendrogram colour threshold.

    Returns:
        DataFrame indexed by ``'pathway'`` with an integer ``'cluster'`` column.

    Side effects: ``sns.reset_orig()`` resets the global plot style, and
    ``random.seed(0)`` reseeds Python's global RNG.
    """
    import random
    random.seed(0)

    from scipy.cluster.hierarchy import dendrogram, fcluster

    # plot tree cutting threshold
    sns.reset_orig()
    plt.figure(figsize=(8, 15))
    dendrogram(Z,
               orientation='left',
               no_labels=True,
               color_threshold=CUT_THRESHOLD,
               above_threshold_color='#bcbcbc')
    plt.axvline(CUT_THRESHOLD, color='k', linestyle='--')
    plt.gca().invert_yaxis()
    plt.xlabel('Dendrogram distance')
    plt.savefig('%s/pathway_dendrogram_tree_cut_threshold.pdf' % (sc.settings.figdir),
                bbox_inches='tight')

    cluster_ids = fcluster(Z, CUT_THRESHOLD, criterion='distance')
    return pd.DataFrame({'cluster': cluster_ids},
                        index=pd.Index(pathways, name='pathway'))

In [None]:
from pylab import *

def make_pathway_cmap(cluster_assign, pathways, seed=None):
    """Assign a hex colour to every pathway according to its reduced cluster label.

    Colours come from the 'Set2' colormap resampled to the number of distinct
    ``cluster_reduced`` labels, then shuffled. Label 0 is always mapped to white
    (``'#ffffff'``).

    ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9, so the colormap registry
    is used, with a fallback for older releases.

    Args:
        cluster_assign: DataFrame with a ``'cluster_reduced'`` column, row-aligned with
            ``pathways``.
        pathways: pathway names; they become the index of the output.
        seed: optional seed for the colour shuffle. ``None`` (default) keeps the previous
            behaviour of shuffling with NumPy's global RNG.

    Returns:
        (df_colors, color_dict): DataFrame indexed by pathway with a ``'color'`` column,
        and the label -> hex colour mapping.
    """
    import matplotlib
    from matplotlib.colors import rgb2hex

    labels = set(cluster_assign['cluster_reduced'])
    try:
        # Matplotlib >= 3.6
        cmap = matplotlib.colormaps['Set2'].resampled(len(labels))
    except AttributeError:
        # Older Matplotlib without the registry / Colormap.resampled
        from matplotlib import cm
        cmap = cm.get_cmap('Set2', len(labels))
    colors = [rgb2hex(cmap(i)) for i in range(cmap.N)]

    if seed is None:
        np.random.shuffle(colors)
    else:
        np.random.default_rng(seed).shuffle(colors)
    color_dict = dict(zip(labels, colors))
    color_dict[0] = '#ffffff'

    # Object dtype so string colours can be written into the 0-initialised column
    # without pandas' incompatible-dtype FutureWarning.
    df_colors = pd.DataFrame(0, index=pathways, columns=['color'], dtype=object)
    for f, path in zip(cluster_assign['cluster_reduced'], pathways):
        df_colors.loc[path] = color_dict[f]

    return df_colors, color_dict

In [None]:
def load_GTEx_lung_data(path='GTEx_data/GTEx_8_tissues_snRNAseq_atlas_071421.public_obs.h5ad'):
    """Load the GTEx snRNA-seq atlas and keep only lung nuclei.

    Args:
        path: h5ad file to read. Defaults to the previously hard-coded location.

    Returns:
        AnnData copy of the ``obs['tissue'] == 'lung'`` observations. It carries an extra
        ``layers['scaled']``: a copy of ``X`` scaled by ``sc.pp.scale`` with
        ``max_value=10``.
    """
    ad_gtex = sc.read(path)
    ad_gtex_lung = ad_gtex[ad_gtex.obs['tissue'] == 'lung'].copy()
    del ad_gtex  # drop the full multi-tissue atlas before the scaling step

    ad_gtex_lung.layers['scaled'] = ad_gtex_lung.X.copy()
    sc.pp.scale(ad_gtex_lung, layer='scaled', max_value=10)

    return ad_gtex_lung

In [None]:
def load_LUAD_data(h5ad_path='GSE131907_raw.h5ad',
                   annotation_path='GSE131907_Lung_Cancer_cell_annotation.txt'):
    """Load the GSE131907 lung-cancer scRNA-seq data restricted to lung samples.

    Steps:

    1. Read the raw data and join the per-cell annotation table.
    2. Keep cells whose ``Sample_Origin`` is ``'tLung'`` or ``'nLung'``.
    3. Map the published labels onto this project's broad (``'Cell types level 2'``)
       and numbered (``'Cell type (numbers)'``) categories.
    4. Normalise, log1p-transform and scale ``X``.

    Args:
        h5ad_path: raw-count h5ad file. Defaults to the previously hard-coded name.
        annotation_path: tab-separated annotation table whose ``Index`` column holds cell
            IDs. Defaults to the previously hard-coded name.

    Returns:
        AnnData with ``layers['counts']`` (counts before normalisation). ``X`` and
        ``layers['scaled']`` both hold the normalised, log1p, scaled data
        (``max_value=10``). Labels missing from the mapping dicts become NaN.
    """
    ad_LUAD = sc.read(h5ad_path)
    LUAD_metadata = pd.read_csv(annotation_path, delimiter='\t')
    LUAD_metadata.set_index('Index', inplace=True)
    # NOTE(review): the inner join keeps only cells present in both tables. If the
    # annotation lacks any cell, obs would no longer match X -- confirm none are missing.
    ad_LUAD.obs = pd.concat([ad_LUAD.obs, LUAD_metadata], axis=1, join="inner")

    ad_LUAD.obs['n_counts'] = ad_LUAD.obs['nCount_RNA'].copy()
    ad_LUAD.obs['n_genes'] = ad_LUAD.obs['nFeature_RNA'].copy()
    ad_LUAD.obs['Sample ID'] = ad_LUAD.obs['Sample'].copy()

    ad_LUAD_lung = ad_LUAD[ad_LUAD.obs['Sample_Origin'].isin(['tLung', 'nLung']), :].copy()

    # match adata.obs['Cell types level 2']
    broad_class_dict = {'B lymphocytes':'Immune (lymphocyte)',
                         'Endothelial cells':'Endothelial cell',
                         'Epithelial cells':'Epithelial cell',
                         'Fibroblasts':'Fibroblast',
                         'MAST cells':'Immune (myeloid)',
                         'Myeloid cells':'Immune (myeloid)',
                         'T/NK cells':'Immune (lymphocyte)'
                       }
    ad_LUAD_lung.obs['Cell types level 2'] = ad_LUAD_lung.obs['Cell_type.refined'].map(broad_class_dict)

    # match 'Broad cell type (numbers)'
    cell_type_dict = {'AT1':'5. Epithelial cell (alveolar type I)',
                    'AT2':'6. Epithelial cell (alveolar type II)',
                    'Activated DCs':'21. Immune (DC/macrophage)',
                    'Alveolar Mac':'25. Immune (alveolar macrophage)',
                    'CD141+ DCs':'21. Immune (DC/macrophage)',
                    'CD163+CD14+ DCs':'21. Immune (DC/macrophage)',
                    'CD1c+ DCs':'21. Immune (DC/macrophage)',
                    'CD207+CD1a+ LCs':'21. Immune (DC/macrophage)',
                    'CD4+ Th':'24. Immune (T cell)',
                    'CD8 low T':'24. Immune (T cell)',
                    'CD8+/CD4+ Mixed Th':'24. Immune (T cell)',
                    'COL13A1+ matrix FBs':'17. Fibroblast',
                    'COL14A1+ matrix FBs':'17. Fibroblast',
                    'Ciliated':'9. Epithelial cell (ciliated)',
                    'Club':'10. Epithelial cell (club)',
                    'Cytotoxic CD8+ T':'24. Immune (T cell)',
                    'EPCs':'3. Endothelial cell (vascular)',
                    'Exhausted CD8+ T':'24. Immune (T cell)',
                    'Exhausted Tfh':'24. Immune (T cell)',
                    'FB-like cells':'17. Fibroblast',
                    'Follicular B cells':'19. Immune (B cell)',
                    'GC B cells in the DZ':'19. Immune (B cell)',
                    'GC B cells in the LZ':'19. Immune (B cell)',
                    'GrB-secreting B cells':'19. Immune (B cell)',
                    'Lymphatic ECs':'2. Endothelial cell (lymphatic)',
                    'MALT B cells':'19. Immune (B cell)',
                    'MAST':'26. Immune (mast cell)',
                    'Malignant cells':'44. Malignant cells', # new
                    'Mesothelial cells':'45. Mesothelial cells', # new
                    'Microglia/Mac':'21. Immune (DC/macrophage)',
                    'Monocytes':'21. Immune (DC/macrophage)',
                    'Myofibroblasts':'17. Fibroblast',
                    'NK':'23. Immune (NK cell)',
                    'Naive CD4+ T':'24. Immune (T cell)',
                    'Naive CD8+ T':'24. Immune (T cell)',
                    'Pericytes':'39. Pericyte/SMC',
                    'Plasma cells':'19. Immune (B cell)',
                    'Pleural Mac':'21. Immune (DC/macrophage)',
                    'Smooth muscle cells':'39. Pericyte/SMC',
                    'Stalk-like ECs':'3. Endothelial cell (vascular)',
                    'Tip-like ECs':'3. Endothelial cell (vascular)',
                    'Treg':'24. Immune (T cell)',
                    'Tumor ECs':'40. Endothelial cell (tumor)', # new
                    'mo-Mac':'21. Immune (DC/macrophage)',
                    'pDCs':'21. Immune (DC/macrophage)',
                    'tS1':'41. tumor State 1', # new
                    'tS2':'42. tumor State 2', # new
                    'tS3':'43. tumor State 3' # new
                 }
    ad_LUAD_lung.obs['Cell type (numbers)'] = ad_LUAD_lung.obs['Cell_subtype'].map(cell_type_dict)

    ad_LUAD_lung.obs.drop(columns=['nCount_RNA', 'nFeature_RNA', 'Barcode', 'orig.ident'], inplace=True)

    # Keep counts before X is normalised in place. An identical copy used to be made
    # earlier as well; only obs changes in between, so it just held an extra matrix.
    ad_LUAD_lung.layers['counts'] = ad_LUAD_lung.X.copy()
    sc.pp.normalize_total(ad_LUAD_lung)
    sc.pp.log1p(ad_LUAD_lung)
    sc.pp.scale(ad_LUAD_lung, max_value=10)
    ad_LUAD_lung.layers['scaled'] = ad_LUAD_lung.X.copy()

    return ad_LUAD_lung

In [None]:
def get_cell_type_markers(ad, group_name, num_genes=50, test_method='t-test', use_raw=True):
    """Rank marker genes per group with scanpy and return the top ``num_genes`` per group.

    ``group_name`` is an ``obs`` column, or one of the dataset aliases ``'GTEx'``
    (-> ``'Cell types level 2'``) and ``'LUAD_scRNA'`` (-> ``'Cell_type'``). scanpy also
    stores the full result in ``ad.uns['rank_genes_groups']``.

    Returns:
        (marker_dict, cluster_markers): dict group -> list of the top gene names, and the
        same table as a DataFrame (rows = rank, columns = groups).
    """
    group_aliases = {'GTEx': 'Cell types level 2', 'LUAD_scRNA': 'Cell_type'}
    groupby_key = group_aliases.get(group_name, group_name)

    sc.tl.rank_genes_groups(ad, groupby_key, method=test_method, use_raw=use_raw)

    ranked_names = ad.uns['rank_genes_groups']['names']
    cluster_markers = pd.DataFrame(ranked_names).head(num_genes)
    return cluster_markers.to_dict(orient='list'), cluster_markers

In [None]:
def valid_markers(marker_dict, ad_target):
    """Keep only the marker genes present in ``ad_target.var_names``.

    Prints each key with the number of genes it retains. Gene order within each list
    is preserved.

    Returns:
        dict with the same keys as ``marker_dict`` and the filtered gene lists.
    """
    target_genes = ad_target.var_names
    filtered = {}
    for cell_type, genes in marker_dict.items():
        kept = list(filter(lambda gene: gene in target_genes, genes))
        print(cell_type, len(kept))
        filtered[cell_type] = kept
    return filtered

In [None]:
def score_with_markers(marker_dict_valid, ad_score, target_clusters, score_base):
    """Score every observation against several marker gene sets and assign a best match.

    For each cluster in ``target_clusters``, ``sc.tl.score_genes`` writes
    ``obs['<score_base>_<cluster>']``, and a histogram of that score is saved to
    ``sc.settings.figdir``. Two assignments are then added to ``obs``:

    * ``rank_assignment`` -- each score column is ranked across observations (highest
      score = rank 1); the column where the observation has its smallest rank wins.
    * ``max_score_assignment`` -- the column with the highest raw score.

    Returns:
        (ad_score, score_names): the modified AnnData and the list of score column names.
    """
    # score every requested cluster first
    for clust in target_clusters:
        column = '%s_%s' % (score_base, clust)
        sc.tl.score_genes(ad_score, marker_dict_valid[clust], score_name=column)

    score_names = [score_base + '_' + clust for clust in target_clusters]

    # one histogram per score
    for name in score_names:
        plt.figure()
        ad_score.obs[name].hist(bins=50, alpha=0.6)
        plt.xlabel(name)
        plt.axvline(0, color='k')
        plt.grid(False)
        plt.savefig('%s/cell_type_score_%s_%s.pdf' % (sc.settings.figdir, score_base, name))

    ranks = ad_score.obs[score_names].rank(ascending=False)
    ad_score.obs['rank_assignment'] = ranks.idxmin(axis=1)
    ad_score.obs['max_score_assignment'] = ad_score.obs[score_names].idxmax(axis=1)

    return ad_score, score_names
