In [1]:
import numpy as np
import pandas as pd
import scanpy.api as sc
%matplotlib inline
import matplotlib.pyplot as plt

import sys 
import inspect
import seaborn as sns
import os

# Scanpy log verbosity levels: errors (0), warnings (1), info (2), hints (3).
sc.settings.verbosity = 3
sc.settings.set_figure_params(color_map='viridis', dpi=80)
sc.logging.print_versions()


In a future version of Scanpy, `scanpy.api` will be removed.
Simply use `import scanpy as sc` and `import scanpy.external as sce` instead.



scanpy==1.4.5.post1 anndata==0.7.1 umap==0.3.9 numpy==1.17.2 scipy==1.4.1 pandas==0.24.2 scikit-learn==0.22 statsmodels==0.11.0 python-igraph==0.7.1 louvain==0.6.1


In [2]:
def subsample_counts(adata_here,subsampling_factor,random_state=None):
    """Binomially downsample every count stored in ``adata_here.X``.

    Each stored count ``c`` is replaced by a draw from
    ``Binomial(c, subsampling_factor)``, i.e. every unit of the count is kept
    independently with probability ``subsampling_factor``.

    Parameters
    ----------
    adata_here : AnnData-like
        Object whose ``.X`` holds counts, either as a scipy sparse matrix (any
        format) or as a dense array. ``adata_here.X`` itself is not modified.
    subsampling_factor : float
        Probability in [0, 1] of keeping each unit of count.
    random_state : int or None, optional
        Seed for a private ``np.random.RandomState``. ``None`` (default) draws
        from NumPy's global random state.

    Returns
    -------
    scipy.sparse.csr_matrix
        Float matrix with the same shape as ``adata_here.X`` holding the
        downsampled counts.

    Notes
    -----
    Stored values are cast to int64 (truncating) before sampling, because the
    number of binomial trials must be an integer.
    """
    from scipy.sparse import csr_matrix

    # Work on a CSR copy: the (data, indices, indptr) triple used below is only
    # meaningful for CSR, and copying keeps eliminate_zeros() from mutating the
    # caller's matrix.
    m = csr_matrix(adata_here.X, copy=True)
    m.eliminate_zeros()
    print('nonzeros', m.nnz)

    rng = np.random if random_state is None else np.random.RandomState(random_state)
    # One vectorized draw over all stored counts instead of a per-element Python loop.
    subsampled = rng.binomial(m.data.astype(np.int64), subsampling_factor)

    downsampled = csr_matrix((subsampled, m.indices, m.indptr), dtype=float, shape=m.shape)
    print(np.median(np.array(downsampled.sum(axis=1))))
    return downsampled

#subsample each group to the same number of cells
def subsample_cells(adata_here,num_cells,grouping_variable,random_state=None):
    """Randomly keep at most ``num_cells`` cells from each group of ``grouping_variable``.

    Groups with fewer than ``num_cells`` cells are kept whole and a warning is
    printed.

    Parameters
    ----------
    adata_here : AnnData-like
        Object exposing ``obs`` (DataFrame), ``obs_names`` and ``[cells, :]`` indexing.
    num_cells : int
        Number of cells to sample from each group.
    grouping_variable : str
        Column of ``adata_here.obs`` holding each cell's group label.
    random_state : int or None, optional
        Seed for a private ``random.Random``. ``None`` (default) uses the global
        ``random`` module.

    Returns
    -------
    The result of ``adata_here[cells_keep, :]``. Cells are grouped by label, and
    groups appear in order of first appearance in ``obs``.
    """
    import random

    sampler = random if random_state is None else random.Random(random_state)
    labels = adata_here.obs[grouping_variable]
    cells_keep = []
    # First-appearance order (not set()) so group order, and hence the sampling
    # sequence, is reproducible for a given seed.
    for group in labels.unique():
        group_cells = list(adata_here.obs_names[labels == group])
        if len(group_cells) < num_cells:
            # str() so non-string labels (e.g. integer cluster ids) don't crash the message.
            print('warning: fewer cells than needed for ' + str(group) + '. skipping subsampling')
        else:
            group_cells = sampler.sample(group_cells, num_cells)
        cells_keep.extend(group_cells)
    return(adata_here[cells_keep,:])

In [3]:
def get_sorted_list(df,x,y):
    """Return the distinct values of column ``x``, ordered by the median of ``y`` in each.

    Parameters
    ----------
    df : pandas.DataFrame
    x : str
        Column whose distinct values are ordered.
    y : str
        Numeric column whose per-``x`` median defines the order (ascending).

    Returns
    -------
    list
        Distinct values of ``df[x]``, lowest median first. Values with equal
        medians keep their order of first appearance in ``df``.
    """
    categories = list(pd.unique(df[x]))
    medians = [np.median(df.loc[df[x] == category, y]) for category in categories]
    # Stable sort: ties keep first-appearance order instead of arbitrary set() order.
    order = np.argsort(medians, kind='mergesort')
    return [categories[i] for i in order]

def sorted_catplot(df,x,y,huename,huedict,transpose_axes=False,plottype="violin",figwidth=5,figheight=10):
    """Plot ``y`` per category of ``x``, with categories ordered by their median ``y``.

    The category order comes from ``get_sorted_list(df, x, y)``. All drawing is
    delegated to ``sorted_catplot_given_order``, so the two entry points share
    one plotting implementation instead of two copies.

    Parameters
    ----------
    df : pandas.DataFrame
    x : str
        Categorical column.
    y : str
        Numeric column.
    huename : str
        Column used for hue.
    huedict : dict
        Palette mapping hue values to colors.
    transpose_axes : bool
        If True, categories are placed on the vertical axis.
    plottype : str
        Seaborn categorical plot kind, e.g. "violin".
    figwidth, figheight : float
        Figure size in inches.
    """
    sorted_xs = get_sorted_list(df, x, y)
    sorted_catplot_given_order(df, x, y, sorted_xs, huename, huedict,
                               transpose_axes=transpose_axes, plottype=plottype,
                               figwidth=figwidth, figheight=figheight)
    
def sorted_catplot_given_order(df,x,y,sorted_xs,huename,huedict,transpose_axes=False,plottype="violin",figwidth=5,figheight=10):
    """Plot ``y`` per category of ``x``, with categories in a caller-given order.

    Parameters
    ----------
    df : pandas.DataFrame
    x : str
        Categorical column.
    y : str
        Numeric column.
    sorted_xs : list
        Order of the categories along the categorical axis.
    huename : str or None
        Column used for hue. The hue legend, if any, is removed.
    huedict : dict
        Palette mapping hue values to colors.
    transpose_axes : bool
        If True, categories are placed on the vertical axis (horizontal plot).
    plottype : str
        Seaborn categorical kind. ``getattr(sns, plottype + 'plot')`` is used,
        e.g. "violin", "box" or "strip".
    figwidth, figheight : float
        Figure size in inches.

    Raises
    ------
    ValueError
        If seaborn has no ``<plottype>plot`` function.
    """
    fig, ax1 = plt.subplots(figsize=(figwidth, figheight))

    # catplot is figure-level: it ignores (newer seaborn) or abuses `ax`, and it
    # creates an extra figure. Call the axes-level function for the kind directly.
    plot_func = getattr(sns, plottype + 'plot', None)
    if plot_func is None:
        raise ValueError("Plot kind '%s' is not recognized" % plottype)

    plot_kws = dict(data=df, ax=ax1, order=sorted_xs, hue=huename,
                    palette=huedict, dodge=False)
    if plottype == 'violin':
        # These options only exist on violinplot.
        plot_kws.update(inner='quartiles', cut=0)
        if transpose_axes:
            plot_kws['scale'] = 'width'

    # Explicit orientation, so numeric category labels don't flip the inferred orientation.
    if transpose_axes:
        plot_func(x=y, y=x, orient='h', **plot_kws)
    else:
        plot_func(x=x, y=y, orient='v', **plot_kws)

    # Rotate in place: re-setting labels read before drawing can blank numeric ticks.
    ax1.tick_params(axis='x', labelrotation=90)
    legend = ax1.get_legend()
    if legend is not None:  # no legend is drawn when huename is None
        legend.remove()

In [4]:
def build_bulk(adata_here,grouping_variable,by_batch=True):
    """Compute an in silico bulk set of expression profiles, based on cell labels.

    Parameters
    ----------
    adata_here : `scanpy Anndata`
        Must expose ``X`` (dense array or scipy sparse matrix), ``obs`` and
        ``var_names``. Side effect: ``adata_here.obs['profile']`` is
        (over)written with each cell's profile label.
    grouping_variable : `str`
        The name of the variable that specifies a label for each cell. This
        variable must be accessible as ``adata_here.obs[grouping_variable]``.
    by_batch : `bool`
        If True, build a separate profile for every (batch, label) combination,
        labelled ``"<batch>_<label>"``. This requires ``adata_here.obs["batch"]``.
        If False, build one profile per label, pooling cells across batches.

    Returns
    -------
    profile_matrix_df : pandas.DataFrame of size (number of profiles) x (number of genes)
        Mean expression over the cells of each profile.
        - Rows are profile labels, in order of first appearance among the cells.
        - Columns are ``adata_here.var_names``.
        - Cells whose profile label is missing (NaN) are not assigned to any profile.
    """
    if by_batch:
        # Vectorized label construction. astype(str) also covers non-string
        # (e.g. integer or categorical) labels.
        combined = (adata_here.obs['batch'].astype(str) + '_'
                    + adata_here.obs[grouping_variable].astype(str))
        adata_here.obs['profile'] = combined.values
    else:
        adata_here.obs['profile'] = adata_here.obs[grouping_variable]

    # factorize returns an integer code per cell (-1 for missing labels) and the
    # labels in first-appearance order, giving a deterministic row order.
    codes, profiles = pd.factorize(adata_here.obs['profile'])

    X = adata_here.X
    genes = adata_here.var_names
    profile_matrix = np.zeros((len(profiles), len(genes)))
    for profile_idx in range(len(profiles)):
        rows = np.flatnonzero(codes == profile_idx)
        # .mean works for dense and sparse X without densifying the group.
        # Sparse returns a 1 x n_genes matrix, hence asarray().ravel().
        group_mean = X[rows].mean(axis=0, dtype=np.float64)
        profile_matrix[profile_idx, :] = np.asarray(group_mean).ravel()
    return pd.DataFrame(profile_matrix, index=list(profiles), columns=genes)

def get_corr_mat(df,axis=0,corr_type='spearman'):
    """Pairwise similarity between the rows (axis=0) or columns (axis=1) of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame or 2-D array-like
    axis : {0, 1}
        0 compares rows, 1 compares columns.
    corr_type : {'spearman', 'pearson', 'diff'}
        'spearman' or 'pearson' give the correlation coefficient.
        'diff' gives the sum of absolute element-wise differences.

    Returns
    -------
    numpy.ndarray
        Symmetric (n, n) matrix, where n is the number of rows or columns compared.
        For the correlation types, entries involving a constant (zero standard
        deviation) vector are left at 0, because the coefficient is undefined.

    Raises
    ------
    ValueError
        If ``axis`` or ``corr_type`` is not one of the supported values.
    """
    from scipy.stats import pearsonr, spearmanr

    if axis == 0:
        data = np.array(df)  # np.array copies, so df is never modified
    elif axis == 1:
        data = np.ascontiguousarray(np.array(df).T)
    else:
        raise ValueError("axis must be 0 or 1, got %r" % (axis,))

    if corr_type == 'pearson':
        def pair_value(a, b):
            return pearsonr(a, b)[0]
    elif corr_type == 'spearman':
        def pair_value(a, b):
            return spearmanr(a, b)[0]
    elif corr_type == 'diff':
        def pair_value(a, b):
            return np.sum(np.abs(a - b))
    else:
        raise ValueError("corr_type must be 'pearson', 'spearman' or 'diff', got %r" % (corr_type,))

    n = data.shape[0]
    # Correlations are undefined for constant vectors, but absolute differences are not.
    if corr_type == 'diff':
        skip = [False] * n
    else:
        skip = [np.std(row) == 0 for row in data]

    corr_mat = np.zeros((n, n))
    for i in range(n):
        if skip[i]:
            continue
        a = data[i, :]
        for j in range(i, n):
            if skip[j]:
                continue
            val = pair_value(a, data[j, :])
            corr_mat[i, j] = val
            corr_mat[j, i] = val
    return(corr_mat)


In [None]:
def plot_multiple_densities(adata_subset,variable_for_groups,groups,group_plot_dict,color_dict,
                            scatter_s=3,scatter_alpha=1):
    """Overlay per-group density contours, shaded densities or scatters on one UMAP axis.

    Parameters
    ----------
    adata_subset : AnnData-like
        Needs ``obsm['X_umap']`` (its first two columns are plotted) and
        ``obs[variable_for_groups]``.
    variable_for_groups : str
        Column of ``adata_subset.obs`` holding each cell's group.
    groups : iterable
        Groups to draw, in drawing order (later groups are drawn on top).
    group_plot_dict : dict
        Maps a group to 'contour' (the default), 'shade' or 'scatter'.
    color_dict : dict
        Maps a group to a colormap name for 'contour'/'shade' groups (default
        'Reds'), or to a matplotlib color for 'scatter' groups.
        NOTE: the 'Reds' default is a colormap name, not a color, so scatter
        groups should be given an explicit color.
    scatter_s, scatter_alpha : float
        Marker size and alpha for 'scatter' groups.

    Raises
    ------
    ValueError
        If a group's plot style is not 'contour', 'shade' or 'scatter'.
    """
    umap = adata_subset.obsm['X_umap']
    coords = pd.DataFrame({'umap1': umap[:, 0],
                           'umap2': umap[:, 1],
                           'group': adata_subset.obs[variable_for_groups]})

    fig, plots = plt.subplots()
    fig.set_size_inches(5, 5)
    title_text = ''
    for group in groups:
        group_cmap = color_dict.get(group, 'Reds')
        group_plot = group_plot_dict.get(group, 'contour')
        group_coords = coords[coords['group'] == group]  # filter once per group

        if group_plot in ('contour', 'shade'):
            # Draw directly with the axes-level kdeplot on our axis. This avoids
            # jointplot(kind='kde', ax=...), which builds and discards a whole
            # JointGrid figure per group. NaN coordinates are dropped, as
            # jointplot's dropna=True did.
            kde_coords = group_coords[['umap1', 'umap2']].dropna()
            sns.kdeplot(kde_coords['umap1'].values, kde_coords['umap2'].values,
                        ax=plots, cmap=group_cmap, shade=(group_plot == 'shade'))
        elif group_plot == 'scatter':
            plots.scatter(group_coords['umap1'], group_coords['umap2'],
                          s=scatter_s, alpha=scatter_alpha, color=group_cmap)
        else:
            raise ValueError("unknown plot style %r for group %r; expected "
                             "'contour', 'shade' or 'scatter'" % (group_plot, group))

        # str() so non-string group labels / colors don't break the title.
        title_text = title_text + '\n' + str(group) + ':' + str(group_cmap)
    plots.grid(False)
    plots.set_title(title_text)
    plt.show()
    