In [3]:
import sys
sys.path.insert(0, '/cndd/fangming/CEMBA/snmcseq_dev')
import importlib

from __init__ import *
from __init__jupyterlab import *


import collections
from collections import deque
from scipy import stats
from scipy import optimize 
from scipy.optimize import curve_fit

import queue
# import tables
from scipy import sparse
from sklearn.model_selection import KFold
# from sklearn.decomposition import PCA
# from sklearn.neighbors import NearestNeighbors
# from sklearn.utils.sparsefuncs import mean_variance_axis
import fbpca
from statsmodels.stats.multitest import multipletests


import snmcseq_utils
importlib.reload(snmcseq_utils)
import CEMBA_clst_utils
importlib.reload(CEMBA_clst_utils)

# from CEMBA_run_tsne import run_tsne
# from CEMBA_run_tsne import run_tsne_v2

<module 'CEMBA_clst_utils' from '/cndd/fangming/CEMBA/snmcseq_dev/CEMBA_clst_utils.py'>

## Basic settings 
- Use ```mods``` and ```settings[mod]``` to access modality-specific information.

In [4]:
# Modalities (datasets) processed in this notebook. Each name is used as a key into
# `settings`, `metas`, `gxc_hvftrs` and `n_subs_allmods`, and is substituted into the
# data-file path templates.
mods_selected = [
    'snmcseq_gene',
    'snatac_gene',
    'smarter_cells',
    'smarter_nuclei',
    '10x_cells_v2', 
    '10x_cells_v3',
    '10x_nuclei_v3',
    '10x_nuclei_v3_macosko',
    ]

In [6]:
# Directory holding the per-modality metadata / feature files and the dataset config module.
DATA_DIR = '/cndd/fangming/CEMBA/data/MOp_all/data_freeze_neurons'

# fixed dataset configs
# DATA_DIR is placed first on sys.path so `__init__datasets` is searched for there first;
# the reload picks up edits made to the module since it was last imported in this kernel.
sys.path.insert(0, DATA_DIR)
import __init__datasets
importlib.reload(__init__datasets)
# NOTE(review): this star import presumably provides `settings` and `mods` used in later cells — confirm.
from __init__datasets import *

# Path templates: {0} is the modality name; {1} (hvftrs_f only) is the file extension.
meta_f = os.path.join(DATA_DIR, '{0}_metadata.tsv')
hvftrs_f = os.path.join(DATA_DIR, '{0}_hvfeatures.{1}')
hvftrs_gene = os.path.join(DATA_DIR, '{0}_hvfeatures.gene')
hvftrs_cell = os.path.join(DATA_DIR, '{0}_hvfeatures.cell')

In [7]:
# Per-modality metadata tables, keyed by modality name and indexed by that
# modality's cell-ID column (settings[mod].cell_col).
metas = collections.OrderedDict()
for mod in mods_selected:
    metas[mod] = (
        pd.read_csv(meta_f.format(mod), sep="\t")
        .reset_index()
        .set_index(settings[mod].cell_col)
    )
    print(mod, metas[mod].shape)

snmcseq_gene (9366, 32)
snatac_gene (54844, 13)


  interactivity=interactivity, compiler=compiler, result=result)


smarter_cells (6244, 129)
smarter_nuclei (5911, 146)
10x_cells_v2 (121440, 8)
10x_cells_v3 (69727, 8)
10x_nuclei_v3 (39706, 8)
10x_nuclei_v3_macosko (101647, 19)


In [8]:
# Load the highly-variable-feature matrix of every selected modality into gxc_hvftrs[mod],
# printing its shape and the load time, and checking that its cells are ordered like metas[mod].
gxc_hvftrs = collections.OrderedDict()
for mod in mods_selected:
    print(mod)
    ti = time.time()
    
    # 'mc' modalities are read from a TSV into a DataFrame whose columns are the cells
    if settings[mod].mod_category == 'mc':
        f_mat = hvftrs_f.format(mod, 'tsv')
        gxc_hvftrs[mod] = pd.read_csv(f_mat, sep='\t', header=0, index_col=0) 
        print(gxc_hvftrs[mod].shape, time.time()-ti)
        assert np.all(gxc_hvftrs[mod].columns.values == metas[mod].index.values) # make sure cell names are in the same order as metas (important if saving a knn mat)
        continue
        
        
    # all other modalities: an .npz matrix plus separate gene-name and cell-name files
    f_mat = hvftrs_f.format(mod, 'npz')
    f_gene = hvftrs_gene.format(mod)
    f_cell = hvftrs_cell.format(mod)
    _gxc_tmp = snmcseq_utils.load_gc_matrix(f_gene, f_cell, f_mat)
    _gene = _gxc_tmp.gene
    _cell = _gxc_tmp.cell
    _mat = _gxc_tmp.data
    
    # re-wrap the loaded pieces in a GC_matrix (fields: gene, cell, data)
    gxc_hvftrs[mod] = GC_matrix(_gene, _cell, _mat)
    assert np.all(gxc_hvftrs[mod].cell == metas[mod].index.values) # make sure cell names are in the same order as metas (important if saving a knn mat)
    print(gxc_hvftrs[mod].data.shape, time.time()-ti)
    

snmcseq_gene
(4754, 9366) 29.89783787727356
snatac_gene
(6345, 54844) 4.769465446472168
smarter_cells
(5743, 6244) 1.6672003269195557
smarter_nuclei
(5400, 5911) 1.023179292678833
10x_cells_v2
(4067, 121440) 17.558100938796997
10x_cells_v3
(4694, 69727) 16.761533498764038
10x_nuclei_v3
(4150, 39706) 4.014025926589966
10x_nuclei_v3_macosko
(4194, 101647) 15.60694670677185


In [9]:
# Sanity check: for every modality except 'snmcseq_gene', the cell names of the
# feature matrix must match the metadata index exactly, in the same order.
for mod in mods_selected:
    if mod != 'snmcseq_gene':
        assert metas[mod].index.tolist() == gxc_hvftrs[mod].cell.tolist()

## Run iterative clustering

- Starting from ```gxc_hvftrs``` and ```settings```


### Functions - iterative clustering

In [None]:
def connect_nodes(edgelist):
    """Group nodes into connected components of an undirected graph.

    Parameters
    ----------
    edgelist : sequence of edges
        Each edge is an iterable of two nodes (e.g. a 2-element set or tuple).
        A 1-element edge is also accepted and treated as a self-loop; this is
        what the set ``{i, i}`` collapses to.

    Returns
    -------
    list of set
        One set of nodes per connected component. Returns ``[]`` when
        `edgelist` is empty.

    Raises
    ------
    ValueError
        If an edge has no nodes or more than two nodes.
    """
    merged_graphs = []
    if not edgelist:
        return []

    for edge in edgelist:
        nodes = tuple(edge)
        if len(nodes) == 2:
            i, j = nodes
        elif len(nodes) == 1:
            # self-loop: a set built as {i, i} only holds one element
            i = j = nodes[0]
        else:
            raise ValueError(
                "Each edge must contain 1 or 2 nodes, got {}".format(len(nodes)))

        # find the existing component (if any) that holds each endpoint
        graph_i = None
        graph_j = None
        for graph in merged_graphs:
            if i in graph:
                graph_i = graph
            if j in graph:
                graph_j = graph

        if graph_i is not None and graph_j is not None:
            if graph_i is not graph_j:
                # the edge bridges two components: fold j's component into i's
                graph_i.update(graph_j)
                merged_graphs = [g for g in merged_graphs if g is not graph_j]
        elif graph_i is not None:
            graph_i.add(j)
        elif graph_j is not None:
            graph_j.add(i)
        else:
            # neither endpoint seen before: start a new component
            merged_graphs.append({i, j})

    return merged_graphs

def DE_genes(X_fr, X_bck, fc=2, alpha=0.05, option='one-sided', th1=0.3, th2=0.3):
    """Return the number of differentially expressed genes of X_fr relative to X_bck.

    A gene is counted when all of the following hold:
      1. mean(X_fr) - mean(X_bck) > log10(fc)
         (the absolute difference is used when option='two-sided');
      2. the Mann-Whitney U test ('greater' for one-sided, 'two-sided' otherwise)
         is rejected after Benjamini-Hochberg FDR correction at level `alpha`;
      3. the fraction of X_fr cells with value > log10(2) exceeds `th1`;
      4. the same fraction computed on X_bck is below `th2` times the X_fr fraction.

    X_fr: dataframe cellxgene (foreground cells)
    X_bck: dataframe cellxgene (background cells); gene axis must equal that of X_fr
    fc: fold-change threshold, compared on a log10 scale
    alpha: FDR level passed to multipletests
    option: 'one-sided' or 'two-sided'
    th1, th2: expression-fraction thresholds (see 3 and 4 above)

    NOTE(review): thresholds assume the values are log10(CPM+1), so that
    "> log10(2)" means CPM > 1 — confirm against the caller.

    Returns the count as a numpy integer.
    """
    assert X_fr.columns.tolist() == X_bck.columns.tolist()
    assert option in ['one-sided', 'two-sided']

    two_sided = (option == 'two-sided')
    alternative = 'two-sided' if two_sided else 'greater'

    # difference of means on the log scale, as a plain array in column order
    genes_logfc = (X_fr.mean(axis=0) - X_bck.mean(axis=0)).values
    if two_sided:
        genes_logfc = np.abs(genes_logfc)

    # per-gene p-values
    ps = []
    for col in X_fr:
        try:
            _, p = stats.mannwhitneyu(X_fr[col].values,
                                      X_bck[col].values,
                                      alternative=alternative,
                                      )
        except ValueError:
            # the test cannot be computed (e.g. older SciPy raises when all
            # values are identical): treat the gene as not significant
            p = 1
        ps.append(p)

    rejs, *_ = multipletests(ps, alpha=alpha, method='fdr_bh')

    # fold-change and FDR criteria
    sig_genes = (genes_logfc > np.log10(fc)) & rejs

    # Expression-fraction criteria. Fractions are computed for every gene as plain
    # arrays so the boolean masks combine positionally rather than through
    # pandas index alignment.
    expr_th = np.log10(1+1)
    frac_fr = (X_fr > expr_th).sum(axis=0).values / len(X_fr)
    frac_bck = (X_bck > expr_th).sum(axis=0).values / len(X_bck)

    sig_genes = sig_genes & (frac_fr > th1) & (frac_bck < th2*frac_fr)
    num_sig_genes = sig_genes.sum()

    return num_sig_genes

In [None]:
# iterative merging clusters (differential expression, # of cells)
def merge_clusters_iteratively(tmp_pcX, tmp_X, 
                           new_cells_batch_local_idx, new_cells_batch, 
                           num_sig_genes_th=10, n_min=30):
    """Apply merge_clusters repeatedly until a pass leaves the number of clusters unchanged.

    Arguments are forwarded to merge_clusters unchanged; the cluster lists
    returned by one pass are the input of the next.

    returns:
        - (local-index cluster list, global-index cluster list) from the last pass
    raises:
        - ValueError if a pass returns more clusters than it was given
    """
    local_idx_batch = new_cells_batch_local_idx
    global_idx_batch = new_cells_batch

    while True:
        n_before = len(global_idx_batch)
        local_idx_batch, global_idx_batch = merge_clusters(
            tmp_pcX, tmp_X,
            local_idx_batch, global_idx_batch,
            num_sig_genes_th=num_sig_genes_th,
            n_min=n_min)
        n_after = len(global_idx_batch)

        if n_after > n_before:
            raise ValueError("Should not get more clusters during cluster merging")
        if n_after == n_before:  # nothing merged in this pass: done
            return local_idx_batch, global_idx_batch
        # otherwise some clusters were merged; try another pass
        

def merge_clusters(tmp_pcX, tmp_X, new_cells_batch_local_idx, new_cells_batch, num_sig_genes_th=8, n_min=30):
    """Run one pass of merging each cluster with its nearest neighboring cluster.

    For every cluster, the nearest other cluster is found by centroid distance in
    PC space. The pair is marked for merging when the number of DE genes
    (cluster vs. its neighbor, see DE_genes) is below `num_sig_genes_th`, or when
    the cluster has fewer than `n_min` cells. Marked pairs that share a cluster
    are chained together with connect_nodes.

    tmp_pcX: cellxpc ndarray
    tmp_X: cellxgene dataframe
    new_cells_batch_local_idx: list of arrays of cell indices used to select rows in tmp_X and tmp_pcX
    new_cells_batch: matched with new_cells_batch_local_idx, provides the global idx of the same cells

    returns:
        - updated_cells_batch_local_idx: list of arrays (local idx of cells)
        - updated_cells_batch: list of arrays (global idx of cells)
      Merged clusters come first, followed by the clusters kept unchanged.

    raises:
        - ValueError if new_cells_batch is empty
    """
    if len(new_cells_batch) == 1:
        return new_cells_batch_local_idx, new_cells_batch
    elif len(new_cells_batch) < 1:
        raise ValueError("new_cells_batch can't be empty!")

    # centroids in PC space
    centroids = np.array([tmp_pcX[cell_idxs, :].mean(axis=0) for cell_idxs in new_cells_batch_local_idx])
    # nearest neighboring cluster for each cluster (column 0 is presumably the cluster itself — confirm)
    clsts_nn_idx = (CEMBA_clst_utils.gen_knn_annoy(centroids, 2, form='list', metric='euclidean', verbose=False))[:, 1]

    mergelist = []   # pairs of cluster indices to merge, each stored as a set
    merged = set()   # every cluster index that takes part in some merge
    for clst_idx, clst_nn_idx in enumerate(clsts_nn_idx):
        if {clst_nn_idx, clst_idx} in mergelist:
            # this pair was already marked from the other side
            print("# DE genes skipped (previously done)")
            continue

        # test_fr against test_bck (its nearest cluster)
        test_fr = new_cells_batch_local_idx[clst_idx]
        test_bck = new_cells_batch_local_idx[clst_nn_idx]
        num_sig_genes = DE_genes(tmp_X.iloc[test_fr, :], tmp_X.iloc[test_bck, :])
        print("# DE genes: {}".format(num_sig_genes))

        if num_sig_genes < num_sig_genes_th or len(test_fr) < n_min:
            # record merge clst, clst_nn
            mergelist.append({clst_idx, clst_nn_idx})
            merged.add(clst_idx)
            merged.add(clst_nn_idx)

    # chain pairwise merges into groups of clusters
    mergelist = connect_nodes(mergelist)
    keeplist = [clst for clst in range(len(new_cells_batch)) if clst not in merged]

    updated_cells_batch_local_idx = []
    updated_cells_batch = []
    for merge_case in mergelist:
        merge_case = list(merge_case)
        # Index the Python lists directly: wrapping a list of unequal-length arrays
        # in np.array() is rejected by recent NumPy versions.
        updated_cells_batch.append(
            np.hstack([new_cells_batch[i] for i in merge_case]))
        updated_cells_batch_local_idx.append(
            np.hstack([new_cells_batch_local_idx[i] for i in merge_case]))
    for clst_idx in keeplist:
        updated_cells_batch.append(new_cells_batch[clst_idx])
        updated_cells_batch_local_idx.append(new_cells_batch_local_idx[clst_idx])

    print("Merging: {}".format(mergelist))
    print("Keeping: {}".format(keeplist))

    return updated_cells_batch_local_idx, updated_cells_batch

In [None]:
def iter_clst_routines_merging(X, npc=50, k=30, num_sig_genes_th=8, n_min=30):
    """Iteratively split cells into clusters, merging back splits that are not supported.

    X: cellxgene dataframe (logCPM)

    Cell subsets are processed first-in-first-out, starting with all rows of X.
    Each subset gets its own PCA (at most `npc` components, capped at the number
    of cells in the subset), is clustered with neighborhood size `k`, and the
    resulting clusters are passed through merge_clusters_iteratively. If more than
    one cluster survives, each is queued for another split; otherwise the subset
    is recorded as a final cluster.

    returns:
        - res_cells: list of arrays of row positions in X, one per final cluster
        - res_round: list with the round at which each final cluster was finalized
    """
    t_start = time.time()

    # two parallel FIFO queues: the subset to split and the round it belongs to
    pending_cells = deque([np.arange(len(X))])
    pending_rounds = deque([1])

    final_cells = []
    final_rounds = []

    while pending_cells:
        # take the oldest pending subset
        cells_global = np.array(pending_cells.popleft())
        round_now = pending_rounds.popleft()
        n_cells = len(cells_global)
        cells_local = np.arange(n_cells)
        X_subset = X.iloc[cells_global, :]

        # PCA restricted to this subset
        U, s, Vt = fbpca.pca(X_subset.values, min(npc, n_cells))
        if npc > n_cells:
            print("Actual number of principal components: {}".format(n_cells))
        pcs = U.dot(np.diag(s))

        # clustering, labelled by local (within-subset) index
        clst_table = CEMBA_clst_utils.clustering_routine(
            pcs, cells_local, k,
            metric='euclidean', option='plain', n_trees=10, search_k=-1)

        # one entry per cluster, in both local and global indexing
        groups_local = []
        groups_global = []
        for _, members in clst_table.groupby('cluster'):
            local_ids = members.index.values
            groups_local.append(local_ids)
            groups_global.append(cells_global[local_ids])

        # merge clusters that fail the DE-gene / size criteria
        _, groups_merged = merge_clusters_iteratively(
            pcs, X_subset,
            groups_local,
            groups_global,
            num_sig_genes_th=num_sig_genes_th,
            n_min=n_min)

        if len(groups_merged) > 1:
            # the split holds: every child is queued for the next round
            next_round = round_now + 1
            for child_cells in groups_merged:
                pending_cells.append(child_cells)
                pending_rounds.append(next_round)
        else:
            # no split survived: this subset is a final cluster of the current round
            final_cells.append(groups_merged[0])
            final_rounds.append(round_now)

        print("Total time since start: {} \n# left in queue: {}, # finalized: {}, # total: {}".format(
                time.time()-t_start, len(pending_rounds), len(final_rounds),
                len(pending_rounds)+len(final_rounds)))

    return final_cells, final_rounds

In [None]:
# Print, for each modality in `mods`, the number of cells whose cluster label is not -1.
for mod in mods:
    _x = metas[mod]
    n = len(_x[_x[settings[mod].cluster_col]!=-1])
    print(mod, n)

# Subsample sizes to run for each modality.
# NOTE(review): the keys '10x_cells' / '10x_nuclei' are not in `mods_selected`, and
# `n_subs_allmods` is reassigned in a later cell — this looks like a leftover from an
# earlier dataset version; confirm before running.
n_subs_allmods = collections.OrderedDict({
    'snmcseq_gene': (np.array([1000, 2000, 3000, 4000, 4936])).astype(int), 
    'snatac_gene': (np.array([1000, 2000, 5000, 10000, 12000, 14000, 16000, 18000, 19021])).astype(int), 
    'smarter_cells': (np.array([1000, 2000, 3000, 4000, 5000, 6000, 6244])).astype(int), 
    'smarter_nuclei': (np.array([1000, 2000, 3000, 4000, 5000, 5911])).astype(int), 
    '10x_cells': (np.array([1000, 2000, 5000, 10000, 20000, 
                            30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000, 110000, 117688])).astype(int), 
    '10x_nuclei': (np.array([1000, 2000, 5000, 10000, 20000, 30000, 40000, 50000, 60000, 70000, 77842])).astype(int), 
})

n_subs_allmods

In [None]:
# iterative clustering
# For each k, modality, subsample size and repeat: subsample cells, run
# iter_clst_routines_merging, and write the cluster assignment to a TSV file.
n_repeat = 1 
# NOTE(review): the first `ks` is overwritten immediately, and the unconditional
# `break # 1 k` at the bottom stops after the first k anyway.
ks = [30, 50, 100]
ks = [30]

for k in ks:
    for mod in mods: # mods
#         mod = '10x_cells' #### test 
        ti = time.time()
        
        # get X (a dataframe, cellxgene)
        if mod == 'snmcseq_gene':
            X = gxc_hvftrs[mod].T  # cellxgene
        else:
            X = pd.DataFrame(gxc_hvftrs[mod].data.T.todense(), 
                                index=gxc_hvftrs[mod].cell,
                                columns=gxc_hvftrs[mod].gene,
                                )  # cellxgene

        n_subs = n_subs_allmods[mod]
#         n_subs = [1000, 10000, 50000, 124418]  #### test
#         n_subs = [124418]  #### test
#         n_subs = [1000]  #### test
        for n_sub in n_subs: # subsample
            print("...{}".format(n_sub))

            ti = time.time()
            for i_repeat in range(n_repeat): # repeat subsample
                # subsample (unseeded, so each run draws different cells)
                if n_sub < len(X):
                    X_sub = X.sample(n_sub, replace=False, 
        #                              random_state=1
                                    )
                    cell_list_sub = X_sub.index.values
                else:
                    X_sub = X
                    cell_list_sub = X_sub.index.values

                # iterative clustering 
                res_cells, _ = iter_clst_routines_merging(X_sub, k=k, num_sig_genes_th=5, n_min=30)
                # res_cells holds row positions in X_sub, hence the positional .iloc below;
                # label 0 remains for any cell not present in res_cells
                df_clst = pd.DataFrame(index=cell_list_sub)
                df_clst['cluster'] = 0
                for i, res_cells_clst in enumerate(res_cells):
                    df_clst.iloc[res_cells_clst, 0] = i+1
                

                    
                # one output file per (modality, subsample size, k, repeat)
                output = '/cndd/fangming/CEMBA/results/clst_iter_downsamp_{}_{}_k{}_{}_v2'.format(mod, n_sub, k, i_repeat) 
                df_clst.to_csv(output, sep='\t', na_rep='NA', index=True, header=True)
                
                nclst = len(df_clst['cluster'].unique())
                print("Number of clusters: {}".format(nclst))
                
#             break # 1 subsampled dataset
#         break # 1 mod
    break # 1 k





### Leiden clustering resolution 

In [10]:
# Print, for each selected modality, the number of cells whose cluster label is not -1.
for mod in mods_selected:
    _x = metas[mod]
    n = len(_x[_x[settings[mod].cluster_col]!=-1])
    print(mod, n)

# Subsample sizes to run for each modality, keyed by the names in `mods_selected`.
# The last entry of each array equals the cell count printed above for that modality.
n_subs_allmods = collections.OrderedDict({
    'snmcseq_gene': (np.array([1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9366])).astype(int), 
    'snatac_gene': (np.array([1000, 2000, 5000, 10000, 20000, 40000, 54844])).astype(int), 
    'smarter_cells': (np.array([1000, 2000, 3000, 4000, 5000, 6000, 6244])).astype(int), 
    'smarter_nuclei': (np.array([1000, 2000, 3000, 4000, 5000, 5911])).astype(int), 
    '10x_cells_v2': (np.array([1000, 2000, 5000, 10000, 20000, 
                            30000, 40000, 50000, 60000, 70000, 80000, 90000, 100000, 110000, 121440])).astype(int), 
    '10x_cells_v3': (np.array([1000, 2000, 5000, 10000, 20000, 
                            30000, 40000, 50000, 60000, 69727])).astype(int), 
    '10x_nuclei_v3': (np.array([1000, 2000, 5000, 10000, 20000, 
                            30000, 39706])).astype(int), 
    '10x_nuclei_v3_macosko': (np.array([1000, 2000, 5000, 10000, 20000, 
                            30000, 40000, 50000, 60000, 70000, 80000, 90000, 101647])).astype(int), 
})

n_subs_allmods

snmcseq_gene 9366
snatac_gene 54844
smarter_cells 6244
smarter_nuclei 5911
10x_cells_v2 121440
10x_cells_v3 69727
10x_nuclei_v3 39706
10x_nuclei_v3_macosko 101647


OrderedDict([('snmcseq_gene',
              array([1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9366])),
             ('snatac_gene',
              array([ 1000,  2000,  5000, 10000, 20000, 40000, 54844])),
             ('smarter_cells',
              array([1000, 2000, 3000, 4000, 5000, 6000, 6244])),
             ('smarter_nuclei', array([1000, 2000, 3000, 4000, 5000, 5911])),
             ('10x_cells_v2',
              array([  1000,   2000,   5000,  10000,  20000,  30000,  40000,  50000,
                      60000,  70000,  80000,  90000, 100000, 110000, 121440])),
             ('10x_cells_v3',
              array([ 1000,  2000,  5000, 10000, 20000, 30000, 40000, 50000, 60000,
                     69727])),
             ('10x_nuclei_v3',
              array([ 1000,  2000,  5000, 10000, 20000, 30000, 39706])),
             ('10x_nuclei_v3_macosko',
              array([  1000,   2000,   5000,  10000,  20000,  30000,  40000,  50000,
                      60000,  70000,  80000,  9

In [None]:
# Clustering-resolution sweep on subsampled datasets.
# NOTE(review): this comment used to say "Louvain" while the section title and the
# output file names say "leiden"; the algorithm is chosen inside
# CEMBA_clst_utils.clustering_routine — confirm which one is used.
n_repeat = 10 
k = 30
npc = 50

rs = [1, 2, 3, 4, 6, 8,]
print([str(r) for r in rs])
mods_used = mods_selected
# placeholders: {0} modality, {1} subsample size, {2} resolution, {3} repeat index
output_format = '/cndd/fangming/CEMBA/results_neuron/clst_neuron_leiden_downsamp_{0}_{1}_{2}_{3}_190723.tsv'

# Iterate through
# 1. mod (modalities)
# 2. n_subs (number of cells)
# 3. number of repeats (different subsampling)
# 4. clustering resolution

for mod in mods_used: # mods
    ti = time.time()
    # get X (a dataframe, cellxgene)
    if mod == 'snmcseq_gene':
        X = gxc_hvftrs[mod].T 
    else:
        X = pd.DataFrame(gxc_hvftrs[mod].data.T.todense(), 
                        index=gxc_hvftrs[mod].cell,
                        columns=gxc_hvftrs[mod].gene,
                        ) 

    n_subs = n_subs_allmods[mod]
    for n_sub in n_subs: # subsample
        print("...{}".format(n_sub))
        ti = time.time()
        for i_repeat in range(n_repeat): # repeat subsample
            # subsample (unseeded: every repeat draws a different set of cells)
            if n_sub < len(X):
                X_sub = X.sample(n_sub, replace=False)
            else:
                X_sub = X
            cell_list_sub = X_sub.index.values
            # cell names
            cell_list = cell_list_sub

            # PCA depends only on X_sub, so it is computed once per subsample and
            # shared by all resolutions (previously recomputed for every r).
            U, s, Vt = fbpca.pca(X_sub.values, npc)
            pcX = U * s  # same as U.dot(np.diag(s)): scale each component by its singular value

            for r in rs:  # stringency parameters
                # Clustering 
                print(mod, len(cell_list), r)
                df_clst = CEMBA_clst_utils.clustering_routine(pcX, cell_list, k, 
                                                             resolution=r,
                                                             seed=1, verbose=False,
                                                             metric='euclidean', option='plain', 
                                                             n_trees=10, search_k=-1, num_starts=None)

                # one output file per (modality, subsample size, resolution, repeat)
                output = output_format.format(mod, n_sub, r, i_repeat) 
                df_clst.to_csv(output, sep='\t', na_rep='NA', index=True, header=True)

                nclst = len(df_clst['cluster'].unique())
                print("Number of clusters: {}".format(nclst))


['1', '2', '3', '4', '6', '8']
...1000
snmcseq_gene 1000 1
Number of clusters: 11
snmcseq_gene 1000 2
Number of clusters: 13
snmcseq_gene 1000 3
Number of clusters: 14
snmcseq_gene 1000 4
Number of clusters: 19
snmcseq_gene 1000 6
Number of clusters: 24
snmcseq_gene 1000 8
Number of clusters: 33
snmcseq_gene 1000 1
Number of clusters: 12
snmcseq_gene 1000 2
Number of clusters: 14
snmcseq_gene 1000 3
Number of clusters: 17
snmcseq_gene 1000 4
Number of clusters: 19
snmcseq_gene 1000 6
Number of clusters: 23
snmcseq_gene 1000 8
Number of clusters: 31
snmcseq_gene 1000 1
Number of clusters: 11
snmcseq_gene 1000 2
Number of clusters: 14
snmcseq_gene 1000 3
Number of clusters: 17
snmcseq_gene 1000 4
Number of clusters: 17
snmcseq_gene 1000 6
Number of clusters: 23
snmcseq_gene 1000 8
Number of clusters: 28
snmcseq_gene 1000 1
Number of clusters: 12
snmcseq_gene 1000 2
Number of clusters: 13
snmcseq_gene 1000 3
Number of clusters: 16
snmcseq_gene 1000 4
Number of clusters: 18
snmcseq_gene 10