In [None]:
import multiprocessing as mp
import itertools as it
import functools as ft
import pickle
import sys
import numpy as np
import pandas as pd
import time
import sklearn
import sklearn.preprocessing as pre
import scipy.sparse as sp

In [None]:
# Load the raw table; the first CSV column becomes the index.
# fast_gather_gene_sets only reads the 'symbol' and 'barcode' columns of this frame.
dat = pd.read_csv('./cord_blood_kinases.csv', index_col=0)

In [None]:
def get_cell_sets(row, oe_csr):
    '''Sum rows ``row['lower']`` up to (not including) ``row['upper']`` of ``oe_csr`` along axis 0.

    ``row`` only needs to support ``row['lower']`` / ``row['upper']`` lookups.
    ``oe_csr`` is presumably a scipy.sparse one-hot matrix (judging by the name) — TODO confirm;
    any object supporting slicing and ``.sum(axis=0)`` works. Not called in the visible notebook.
    '''
    start = row['lower']
    stop = row['upper']
    window = oe_csr[start:stop]
    return window.sum(axis=0)

def first_candidates(cells, cell_sets, min_shared_cells):
    '''Seed the search with single-gene clusters found in more than ``min_shared_cells`` cells.

    Parameters
    ----------
    cells : sequence of gene ids
        Kept for backward compatibility only. Labels are taken from
        ``cell_sets.index``, so the returned list and dict keys always agree,
        whatever the order or container type of ``cells``.
    cell_sets : pandas.Series
        Maps gene id -> set of cell (barcode) ids.
    min_shared_cells : int
        Strict lower bound: a gene survives only if ``len(set) > min_shared_cells``.

    Returns
    -------
    (list of frozenset, dict)
        The surviving singleton clusters ``frozenset([gene])``, and a dict mapping
        each of them to its cell set. The list order follows ``cell_sets.index``.
    '''
    survivors = cell_sets[cell_sets.apply(len) > min_shared_cells]
    # Build both outputs from the same labels; indexing ``cells`` by the mask
    # assumed positional alignment with the groupby index.
    seeds = {frozenset([gene]): shared for gene, shared in survivors.items()}
    return list(seeds), seeds

def intersector(tuple_of_candidates, tuple_of_sets):
    '''Merge candidate gene clusters and intersect their cell sets.

    Returns ``(merged, shared)``. ``merged`` is the union of every element of
    ``tuple_of_candidates``, folded left with ``.union``. ``shared`` is the
    intersection of the first two entries of ``tuple_of_sets``; any further
    entries are ignored.
    '''
    def _merge(accumulated, following):
        return accumulated.union(following)

    merged = ft.reduce(_merge, tuple_of_candidates)
    first_set, second_set = tuple_of_sets[0], tuple_of_sets[1]
    shared = first_set & second_set
    return merged, shared

def cell_set_getter(input_list, cell_sets):
    '''Lazily yield ``cell_sets[key]`` for each key of ``input_list``, in order.

    Lookups happen only as the generator is advanced, so a missing key raises
    on the corresponding ``next()``, not at creation.
    '''
    yield from (cell_sets[key] for key in input_list)
        
def make_gener(left, right, min_shared_cells, cell_sets, q):
    '''Worker: join aligned candidate pairs and report the surviving clusters on ``q``.

    For each pair ``(left[i], right[i])`` of gene-id frozensets, intersects
    ``cell_sets[left[i]]`` with ``cell_sets[right[i]]``. Pairs whose shared-cell
    count is strictly greater than ``min_shared_cells`` are kept, keyed by the
    union of the two frozensets. Iteration stops at the shorter of ``left`` and
    ``right``.

    Puts exactly one item on ``q``: a 2-tuple ``(clusters, shared_sets)`` of
    equal-length tuples. If nothing survives, the item is ``((), ())`` rather
    than an empty iterator, so the consumer can always unpack two fields.
    Returns None.
    '''
    kept_clusters = []
    kept_sets = []
    for left_key, right_key in zip(left, right):
        shared = cell_sets[left_key] & cell_sets[right_key]
        if len(shared) > min_shared_cells:
            kept_clusters.append(left_key.union(right_key))
            kept_sets.append(shared)
    # Send concrete tuples: an empty lazy zip() breaks the consumer's
    # zip(*output) / two-way unpack.
    q.put((tuple(kept_clusters), tuple(kept_sets)))
        
def pickle_cells(cells, cell_sets, k):
    '''Persist one level of results as cell_<k>.pickle and cell_sets_<k>.pickle.

    Both files go to the current working directory, overwriting any existing
    file of the same name, and use pickle.HIGHEST_PROTOCOL. The author notes
    these files can be large.
    '''
    outputs = (('cell_', cells), ('cell_sets_', cell_sets))
    for prefix, payload in outputs:
        path = prefix + str(k) + '.pickle'
        with open(path, 'wb') as handle:
            pickle.dump(payload, handle, pickle.HIGHEST_PROTOCOL)

In [None]:
def zip_helper(gener_slice, q):
    '''Drain ``gener_slice`` into a list and put that list on ``q``; returns None.

    Not called in the visible notebook.
    '''
    materialized = [item for item in gener_slice]
    q.put(materialized)

In [None]:
def _encode_ids(dat):
    '''Return a copy of dat's 'symbol' and 'barcode' columns recoded as integer ids.

    Ids in each column follow order of first appearance (enumerate over
    ``Series.unique()``). The caller's frame is left untouched, so the original
    labels stay available. The gene id mapping can be rebuilt as
    ``dict(enumerate(dat['symbol'].unique()))``.
    '''
    encoded = dat[['symbol', 'barcode']].copy()
    for column in ('symbol', 'barcode'):
        id_of = {label: idx for idx, label in enumerate(encoded[column].unique())}
        encoded[column] = encoded[column].map(id_of)
    return encoded


def _launch_workers(candidates_left, candidates_right, cores, worker_kwargs):
    '''Start ``cores`` make_gener processes over contiguous slices of the candidate pairs.

    The first ``cores - 1`` workers each get ``len // cores`` pairs; the last
    worker gets the remainder. Returns the started Process objects.
    '''
    total = len(candidates_left)
    chunk = total // cores
    processes = []
    for worker in range(cores):
        lo = worker * chunk
        hi = total if worker == cores - 1 else lo + chunk
        p = mp.Process(target=make_gener,
                       args=(list(candidates_left[lo:hi]), list(candidates_right[lo:hi])),
                       kwargs=worker_kwargs)
        p.start()
        processes.append(p)
    return processes


def _collect_results(q, processes):
    '''Read one result per worker from ``q``, join the workers, and merge into one dict.

    Each queued item is an iterable holding either two equal-length sequences
    ``(clusters, shared_sets)`` or nothing (an empty result). The same cluster
    can come from several pairs (e.g. {a,b}|{a,c} and {a,b}|{b,c}). It is
    stored once, since either way its set is the intersection over all of
    its genes.
    '''
    merged = {}
    for _ in processes:
        parts = list(q.get())
        if parts:
            clusters, shared_sets = parts
            merged.update(zip(clusters, shared_sets))
    # Join only after draining the queue: a child that has put data blocks
    # on exit until that data is consumed.
    for p in processes:
        p.join()
    return merged


def fast_gather_gene_sets(dat, min_shared_cells=100, min_percent_cells=None, max_cluster_size=sys.maxsize):
    '''Enumerate gene clusters shared by more than ``min_shared_cells`` cells, level by level.

    Parameters
    ----------
    dat : pandas.DataFrame
        Only the 'symbol' (gene) and 'barcode' (cell) columns are read;
        presumably one row per observed (gene, cell) pair. ``dat`` is not modified.
    min_shared_cells : int
        A cluster survives only if strictly more than this many cells contain
        all of its genes.
    min_percent_cells : float or None
        If given, overrides ``min_shared_cells`` with
        ``int(min_percent_cells * number of distinct barcodes)``.
    max_cluster_size : int
        Largest cluster size to generate (inclusive).

    Side effects: prints progress, starts worker processes, and writes
    cell_<k>.pickle / cell_sets_<k>.pickle for every size k that has survivors
    (k=1 holds the single-gene seeds). The pickles contain integer ids assigned
    in order of first appearance in ``dat``. Returns None.
    '''
    st = time.time()
    begin = st
    cores = max(mp.cpu_count() - 1, 1)

    total_cells = dat['barcode'].nunique()
    if min_percent_cells is not None:
        min_shared_cells = int(min_percent_cells * total_cells)

    encoded = _encode_ids(dat)
    cells = encoded['symbol'].unique()
    cell_sets = encoded.groupby('symbol')['barcode'].apply(set)

    en = time.time()
    print('Formatted data in ' + str(en - st) + ' seconds')

    cells, cell_sets = first_candidates(cells, cell_sets, min_shared_cells)
    print(str(len(cells)) + ' genes made have > ' + str(min_shared_cells) + ' cells')

    # Seeds are size-1 clusters; saving them under k=1 keeps the size-2 pass
    # from overwriting them.
    pickle_cells(cells, cell_sets, 1)

    k = 2
    while len(cells) > 0 and k <= max_cluster_size:
        st = time.time()

        # Apriori join: pairs of (k-1)-clusters whose union has exactly k genes.
        pairs = [(a, b) for a, b in it.combinations(cells, 2) if len(a | b) == k]

        if pairs:
            candidates_left, candidates_right = zip(*pairs)
            q = mp.Queue()
            worker_kwargs = {'min_shared_cells': min_shared_cells, 'cell_sets': cell_sets, 'q': q}
            processes = _launch_workers(candidates_left, candidates_right, cores, worker_kwargs)
            print('Finished launching processes in: ' + str(time.time() - st) + ' seconds')
            cell_sets = _collect_results(q, processes)
        else:
            # No pair can form a k-cluster (e.g. a single survivor), so the search ends.
            cell_sets = {}

        # Dict keys are unique, so clusters reached via several pairs count once.
        cells = list(cell_sets)
        n = len(cells)

        en = time.time()
        print('Found ' + str(n) + ' remaining gene clusters with > ' + str(min_shared_cells) + ' of size: ' + str(k))
        print('Iteration took: ' + str(en - st) + ' seconds')
        print('Total time: ' + str(en - begin) + ' seconds')

        if n == 0:
            print('Terminated! Total run time: ' + str(en - begin) + ' seconds')
        else:
            print('Pickling!')
            pickle_cells(cells, cell_sets, k)

        k += 1
        

In [None]:
# Threshold: a cluster must appear in more than int(1% of distinct barcodes) cells.
fast_gather_gene_sets(dat, min_percent_cells=0.01)