In [1]:
import multiprocessing as mp
import itertools as it
import functools as ft
import pickle
import sys
import numpy as np
import pandas as pd
import time
import sklearn
import sklearn.preprocessing as pre
import scipy.sparse as sp

In [2]:
# Scratch cell: build a tuple of two sets from lists to see how they are displayed.
(set([1, 2]), set([3, 4]))

({1, 2}, {3, 4})

In [3]:
def temp(x, y):
    """Return the union of ``x`` and ``y`` as a frozenset."""
    return frozenset(x | y)

# Fold a list of singleton frozensets into a single frozenset holding every element.
ft.reduce(temp, [frozenset([x]) for x in [1, 2, 3, 4]])


frozenset({1, 2, 3, 4})

In [4]:
# Scratch cell: merge two dicts with ** unpacking; keys present in both appear once in the result.
{**{1:'a', 2:'b'}, **{2:'b', 3:'d'}}

{1: 'a', 2: 'b', 3: 'd'}

In [5]:
def get_cell_sets(row, oe_csr):
    """Sum the rows of ``oe_csr`` in the half-open range ``[row['lower'], row['upper'])``.

    Args:
        row: Mapping-like object exposing the 'lower' and 'upper' row bounds.
        oe_csr: Row-sliceable matrix whose slices support ``.sum(axis=0)``.

    Returns:
        The column-wise sum over the selected rows.
    """
    start, stop = row['lower'], row['upper']
    selected_rows = oe_csr[start:stop]
    return selected_rows.sum(axis=0)

def first_candidates(cells, cell_sets, min_shared_cells):
    """Build the size-1 candidates whose summed cell set exceeds ``min_shared_cells``.

    Args:
        cells: Array of labels, positionally aligned with ``cell_sets``.
        cell_sets: pandas Series holding one array-like per label.
        min_shared_cells: Strict lower bound on ``np.sum`` of an entry for it to be kept.

    Returns:
        A ``(candidates, sets)`` tuple: ``candidates`` is a list of singleton
        frozensets (one per kept element of ``cells``) and ``sets`` maps the
        singleton frozenset of each kept ``cell_sets`` index label to its entry.
    """
    count_filter = cell_sets.apply(np.sum) > min_shared_cells
    singletons = [frozenset([label]) for label in cells[count_filter]]
    kept_sets = cell_sets[count_filter].to_dict()
    keyed_sets = {frozenset([label]): members for label, members in kept_sets.items()}
    return singletons, keyed_sets

def intersector(tuple_of_candidates, cell_sets):
    """Merge a pair of candidates and intersect their cell sets.

    Args:
        tuple_of_candidates: Pair of set-like keys present in ``cell_sets``.
        cell_sets: Mapping from candidate key to its array-like cell set.

    Returns:
        ``(merged_key, shared)`` where ``merged_key`` is the union of the two
        keys and ``shared`` is the element-wise logical AND of their entries.
    """
    left, right = tuple_of_candidates
    merged_key = left | right
    shared = np.logical_and(cell_sets[left], cell_sets[right])
    return merged_key, shared

def intersect_chunk(chunk_of_tuples, cell_sets, min_shared_cells, q):
    """Intersect every candidate pair in a chunk and put the survivors on ``q``.

    Args:
        chunk_of_tuples: Iterable of candidate pairs accepted by ``intersector``.
        cell_sets: Mapping from candidate key to its array-like cell set.
        min_shared_cells: Strict lower bound on ``np.sum`` of an intersection
            for the merged candidate to be kept.
        q: Queue-like object; exactly one dict (possibly empty) is put on it.
    """
    merged = {}
    for pair in chunk_of_tuples:
        union_key, shared = intersector(pair, cell_sets)
        # Pairs that merge to the same key overwrite each other; the last one wins.
        merged[union_key] = shared

    survivors = {
        union_key: shared
        for union_key, shared in merged.items()
        if np.sum(shared) > min_shared_cells
    }
    q.put(survivors)

def pickle_cells(cells, cell_sets, k):
    """Dump ``cells`` and ``cell_sets`` to per-``k`` pickle files in the working directory.

    The objects can be large, so they are written to disk instead of being
    kept in memory: ``cells`` goes to ``cell_<k>.pickle`` and ``cell_sets`` to
    ``cell_sets_<k>.pickle``. Existing files with those names are overwritten.
    """
    suffix = str(k) + '.pickle'
    for prefix, payload in (('cell_', cells), ('cell_sets_', cell_sets)):
        with open(prefix + suffix, 'wb') as handle:
            pickle.dump(payload, handle, pickle.HIGHEST_PROTOCOL)

In [6]:
def _candidate_pairs(cells, k, offset=0, stride=1):
    """Lazily yield pairs of candidate sets whose union has exactly ``k`` members.

    Only every ``stride``-th pair of ``it.combinations(cells, 2)``, starting at
    position ``offset``, is inspected, so ``stride`` callers using offsets
    ``0 .. stride-1`` together cover every pair exactly once.
    """
    share = it.islice(it.combinations(cells, 2), offset, None, stride)
    return (pair for pair in share if len(pair[0] | pair[1]) == k)


def _intersect_share(cells, k, offset, stride, cell_sets, min_shared_cells, q):
    """Worker entry point: run ``intersect_chunk`` on this worker's share of the pairs.

    The pair iterator is built inside the worker, so the parent never hands a
    half-consumed lazy iterator to a child process.
    """
    intersect_chunk(_candidate_pairs(cells, k, offset, stride), cell_sets, min_shared_cells, q)


def fast_gather_gene_sets(dat, min_shared_cells = 100, min_percent_cells = None, max_cluster_size = sys.maxsize):
    """Level-wise search for sets of symbols that co-occur in many barcodes.

    Starting from single symbols, sets of size ``k`` are built by merging pairs
    of surviving size ``k-1`` sets whose union has exactly ``k`` members. A set
    survives when more than ``min_shared_cells`` barcodes contain all of its
    members. Each surviving level is written to disk with ``pickle_cells``
    under its set size (``cell_<size>.pickle`` / ``cell_sets_<size>.pickle``).

    Symbols are replaced by integer ids (order of first appearance in
    ``dat['symbol']``) before the search, so the pickled sets contain those
    ids rather than the original symbols.

    Args:
        dat: DataFrame with one row per (barcode, symbol) observation, in
            columns named 'barcode' and 'symbol'. It is not modified.
        min_shared_cells: Strict lower bound on the number of shared barcodes.
            Ignored when ``min_percent_cells`` is given.
        min_percent_cells: Optional fraction (e.g. ``0.01`` for 1%) of the
            number of distinct barcodes; when not None it replaces
            ``min_shared_cells`` with ``int(min_percent_cells * n_barcodes)``.
        max_cluster_size: Exclusive upper bound on the set size: the search
            stops before building sets of this size.

    Returns:
        None. Results are reported through ``print`` and the pickle files.
    """
    begin = st = time.time()
    cores = max(mp.cpu_count()-1, 1)

    total_cells = dat['barcode'].nunique()

    if min_percent_cells is not None:
        min_shared_cells = int(min_percent_cells * total_cells)

    # Work on a private copy: remapping and sorting the caller's frame in place
    # would silently change the data seen by every other notebook cell.
    cell_id_dict = {y: x for x, y in enumerate(dat['symbol'].unique())}
    work = dat[['barcode', 'symbol']].copy()
    work['symbol'] = work['symbol'].map(cell_id_dict)
    work = work.sort_values(by='symbol')

    # After sorting, the rows of one symbol are contiguous; the cumulative row
    # count per symbol is therefore the exclusive upper row bound of its block.
    # size() counts rows (unlike count(), which skips NaN barcodes), so the
    # bounds always line up with the rows of the one-hot matrix built below.
    upper = work.groupby('symbol').size().cumsum()
    slices = pd.DataFrame({'upper': upper})
    slices['lower'] = [0] + list(upper)[:-1]

    # One row per observation, one column per distinct barcode.
    lab_enc = pre.LabelEncoder()
    one_hot = pre.OneHotEncoder(categories='auto')
    oe_data = one_hot.fit_transform(lab_enc.fit_transform(work['barcode'].values).reshape(-1, 1))

    get_cell_partial = ft.partial(get_cell_sets, oe_csr=oe_data)
    cell_sets = slices.apply(get_cell_partial, axis=1)
    # Take the labels from the same index as cell_sets so the positional mask
    # applied in first_candidates is guaranteed to be aligned.
    cells = slices.index.values

    en = time.time()

    print('Formatted data in ' + str(en-st) + ' seconds')

    cells, cell_sets = first_candidates(cells, cell_sets, min_shared_cells)

    print(str(len(cells)) + ' genes made have > ' + str(min_shared_cells) + ' cells')

    # The singletons are the size-1 level; storing them under 1 keeps them from
    # being overwritten by the size-2 results.
    pickle_cells(cells, cell_sets, 1)

    k = 2  # size of the sets built in the next iteration

    while len(cells) > 0 and k < max_cluster_size:
        st = time.time()

        q = mp.JoinableQueue()
        # Worker i handles pairs i, i + cores, i + 2*cores, ... so the work is
        # split evenly and no pair is processed twice.
        workers = [
            mp.Process(
                target=_intersect_share,
                args=(cells, k, offset, cores, cell_sets, min_shared_cells, q),
            )
            for offset in range(cores)
        ]
        for worker in workers:
            worker.start()

        # Drain the queue BEFORE joining: a process that has put data on a
        # queue does not exit until that data has been consumed, so joining
        # first can deadlock.
        out = []
        for _ in workers:
            out.append(q.get())
            q.task_done()

        print('Done with queue!`')

        for worker in workers:
            worker.join()
        q.join()
        q.close()

        cell_sets = {}
        for partial_result in out:
            cell_sets.update(partial_result)
        cells = list(cell_sets.keys())
        n = len(cells)

        en = time.time()

        print('Found ' + str(n) + ' remaining genes with > ' + str(min_shared_cells) + ' of size: ' + str(k))
        print('Iteration took: ' + str(en-st) + ' seconds')

        if n == 0:
            print('Terminated! Total run time: ' + str(en - begin) + ' seconds')
        else:
            pickle_cells(cells, cell_sets, k)

        k += 1
        

In [7]:
# Load the observation table; the first CSV row is the header and the first column becomes the index.
dat = pd.read_csv('./cord_blood_kinases.csv', sep=',', header=0, index_col=0)

  mask |= (ar1 == a)


In [8]:
# Preview the first rows of the loaded table.
dat.head()

Unnamed: 0,barcode,symbol
1,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,STK4
2,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,RIOK3
3,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,PDK2
4,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,MAPK7
5,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,CDK10


In [None]:
# Run the search; min_percent_cells is a fraction of the distinct barcodes (0.01 = 1%).
fast_gather_gene_sets(dat, min_percent_cells = 0.01)

Formatted data in 9.105247974395752 seconds
303 genes made have > 2738 cells
