In [1]:
import multiprocessing as mp
import itertools as it
import functools as ft
import pickle
import sys
import numpy as np
import pandas as pd
import time

In [2]:
def is_valid_candidate(tuple_of_previous_candidates, size_k):
    return len(tuple_of_previous_candidates[0] | tuple_of_previous_candidates[1]) == size_k

def big_valid_helper(chunk_of_candidates, size_k, q):
    is_v_cand = ft.partial(is_valid_candidate, size_k = size_k)
    out = set(filter(is_v_cand, chunk_of_candidates))
    if(len(out) == 0):
        q.put(None)
    else:
        q.put(out)
    return
   

def get_next_valid_candidates(set_of_previous_candidates, n, size_k):
    ''' Input: set of candidates from last iteration.
        Returns: list of tuples of new valid candidates for this iteration.
        Something something multiprocessing
    '''

    next_candidates = list(it.combinations(set_of_previous_candidates, 2))
    
    is_v_cand = ft.partial(is_valid_candidate, size_k = size_k)

    if(n > pow(mp.cpu_count(),3)):
        p = mp.Pool(processes=max(mp.cpu_count()-1,1))
        valid_candidates = p.map(is_v_cand, next_candidates)
        valid_candidates = it.compress(next_candidates, valid_candidates)
        p.close()
    else:
        valid_candidates = filter(is_v_cand, next_candidates)

    return valid_candidates

def check_new_candidate_count(tuple_of_new_candidates, last_candidate_cell_dict):
    '''takes in a tuple of new candidates, then evaluates if the intersection would have enough cells'''
    return len(intersect_candidate_cells(tuple_of_new_candidates, last_candidate_cell_dict))

def union_candidates(tuple_of_new_candidates):
    '''takes in a tuple of new candidates, then evaluates if the intersection would have enough cells'''
    return tuple_of_new_candidates[0] | tuple_of_new_candidates[1]

def check_greater(count, min_shared_cells):
    return count > min_shared_cells

def make_candidate_cell_dict(filtered_candidates, new_candidates, last_candidate_cell_dict):
    '''combines dict entries for the filtered candidates set and returns a new candidate dict matching the new_candidates'''
    combiner = ft.partial(intersect_candidate_cells, last_candidate_cell_dict=last_candidate_cell_dict)
    if(len(new_candidates) > pow(mp.cpu_count(),2)):
        p = mp.Pool(processes=max(mp.cpu_count()-1,1))
        out = dict(zip(new_candidates, list(map(combiner, filtered_candidates))))
        p.close()
        return out
    else:
        return dict(zip(new_candidates, list(map(combiner, filtered_candidates))))
    
def intersect_candidate_cells(filtered_candidates, last_candidate_cell_dict):
    '''just a helper for make_candidate_cell_dict for mapping'''
    return last_candidate_cell_dict[filtered_candidates[0]] & last_candidate_cell_dict[filtered_candidates[1]]
    
def expected_counter(tuple_of_new_candidates, last_candidate_cell_dict, total_cells):
    '''takes in a tuple of new candidates, then evaluates if the intersection would have enough cells'''
    return len(last_candidate_cell_dict[tuple_of_new_candidates[0]]) * len(last_candidate_cell_dict[tuple_of_new_candidates[1]]) / total_cells

def generate_new_candidates(valid_candidates, last_candidate_cell_dict, min_shared_cells, total_cells):
    '''heavy lifter. returns new candidates, new counts, new expected counts, and new candidate cell dicts'''
    counter = ft.partial(check_new_candidate_count, last_candidate_cell_dict=last_candidate_cell_dict)
    count_checker = ft.partial(check_greater,  min_shared_cells=min_shared_cells)
    exp_checker = ft.partial(expected_counter, last_candidate_cell_dict=last_candidate_cell_dict, total_cells=total_cells)

    s = time.time()
    
    if(len(valid_candidates) >  pow(mp.cpu_count(),2)):
        p = mp.Pool(processes=max(mp.cpu_count()-1,1))
        candidate_counts = list(p.map(counter, valid_candidates))
        p.close()
    else:
        candidate_counts = list(map(counter, valid_candidates))
        
    t = time.time()
    
    print(str(t - s) + ' seconds to count candidates')
          
    s = time.time()
 
    # it.compress and check_greater are fast enough
    candidate_filter = list(map(count_checker, candidate_counts))
    filtered_candidates = list(it.compress(valid_candidates, candidate_filter))
    
    if(len(filtered_candidates)==0):
        print('Ran out of candidates! ')
        return set(), set(), set(), {}
    
    # it.compress is fast enough
    new_counts = list(it.compress(candidate_counts, candidate_filter))
          
    t = time.time()
    
    print(str(t - s) + ' seconds to compress')
    
    s = time.time()
          
    if(len(filtered_candidates) >  pow(mp.cpu_count(),2)):
        p = mp.Pool(processes=max(mp.cpu_count()-1,1))
        new_expected_counts = list(p.map(exp_checker, filtered_candidates))
        new_candidates = list(p.map(union_candidates, filtered_candidates))
        p.close()
    else:
        new_expected_counts = list(map(exp_checker, filtered_candidates))
        new_candidates = list(map(union_candidates, filtered_candidates))
          
    t = time.time()
          
    print(str(t - s) + ' seconds to count expected values')
    
    # parallelized locally
    new_candidate_cell_dict = make_candidate_cell_dict(filtered_candidates, new_candidates, last_candidate_cell_dict)
        
    return new_candidates, new_counts, new_expected_counts, new_candidate_cell_dict

### storage helpers ###
def pickle_candidates(new_candidates, new_counts, new_expected_counts, new_candidate_cell_dict, size_k):
    '''These files are gonna be decently big. Do not want to keep them in memory.'''
    with open('candidates_' + str(size_k) + '.pickle', 'wb') as f:
        pickle.dump(zip(list(new_candidates), list(new_counts), list(new_expected_counts)), f, pickle.HIGHEST_PROTOCOL)
    with open('cell_dict_' + str(size_k) + '.pickle', 'wb') as f:
        pickle.dump(new_candidate_cell_dict, f, pickle.HIGHEST_PROTOCOL)
        
def unpickle_candidates(size_k):
    '''unpickler to unpickle the last one'''
    with open('candidates_' + str(size_k) + '.pickle', 'rb') as f:
        candidates, counts, expected_counts = zip(*pickle.load(f))
    with open('cell_dict_' + str(size_k) + '.pickle', 'rb') as f:
        candidate_cell_dict = pickle.load(f)
        
    return candidates, counts, expected_counts, candidate_cell_dict 

### these functions are for n_combinations large, e.g. when there are more candidates for evaluation than we'd like to run at the same time #
def giant_filter(tuple_of_candidates, size_k, last_cell_dict, min_shared_cells):
    if(not is_valid_candidate(tuple_of_candidates, size_k)):
        return None
    elif(len(intersect_candidate_cells(tuple_of_candidates, last_cell_dict)) <= min_shared_cells): 
        return None
    else:
        return tuple_of_candidates
    
def giant_candidater(tuple_of_valid_candidates, last_cell_dict, total_cells):
    cand = union_candidates(tuple_of_valid_candidates)
    new_cell_dict = {cand: intersect_candidate_cells(tuple_of_valid_candidates, last_cell_dict)}
    count = len(new_cell_dict)
    expected_count = last_cell_dict[tuple_of_valid_candidates[0]] * last_cell_dict[tuple_of_valid_candidates[1]] / total_cells

In [3]:
def gather_gene_sets(tidy_df, min_shared_cells = 100, min_percent_cells = None, max_cluster_size = sys.maxsize):
    '''
        Input: tidy_df, a df with columns of ['symbol','barcode','gene name', 'class']
        Output: Writes multiple files of format "candidates_k.pickle" and "cell_dict_k.pickle" for k in 1:max_cluster_size
            
        Assumptions:    barcode corresponds to cell_id
                        gene_names <-> symbol (1:1 relationship)
    '''
    total_cells = tidy_df['barcode'].nunique()
    
    if(min_percent_cells is not None):
        min_shared_cells = int(min_percent_cells * total_cells)
    
    # get the counts for each gene
    candidate_cell_dict = tidy_df.groupby('symbol')
    counts = candidate_cell_dict.count()['barcode'].to_dict()
    candidate_cell_dict = candidate_cell_dict['barcode'].apply(frozenset)
    
    # store the symbols
    next_candidates = counts.keys()

    # filter stuff -- first iteration!
    candidate_filter = list(map(lambda x: x > min_shared_cells, counts.values()))
    candidates = list(map(frozenset, map(lambda x: (x,), it.compress(next_candidates, candidate_filter))))
    counts = list(it.compress(counts.values(), candidate_filter))
    
    # clean the groups dict to get candidate cell sets for each input -- may be slow!
    # print(frozenset(candidate_cell_dict.groups[next(iter(next(iter(candidates))))][0]),)
    candidate_cell_dict = dict(zip(candidates,[candidate_cell_dict[next(iter(x))] for x in candidates]))
    
    # 0 expected counts for first iteration
    expected_counts = [0] * len(candidate_cell_dict)
    
    # store our first entry!
    size_k=1
    pickle_candidates(candidates, counts, expected_counts, candidate_cell_dict, size_k)
    
    n = len(candidate_cell_dict)
    print(str(n) + ' first candidates!')
    
    # loop!
    while(n > 0 and size_k < max_cluster_size):
        size_k += 1
        
        total_combos = pow(n,2) - n 
        
        if(total_combos > pow(4096, 2)): # see if there would be > ~17 million combos
            # break into chunks of ~ 4 million for evaluation
            m = total_combos // pow(2048, 2)
            carry_over = total_combos % pow(2048, 2)
            
            total_done = 0
            ell = (m if carry_over == 0 else m + 1)
            results = []
            
            valid_candidates_iter = it.combinations(candidates, 2)
            print('created iter of size ' + str(total_combos))
            stack = 0
   
            next_candidates = filter(None, [None])
            
            p = mp.Pool(processes = max(mp.cpu_count()-1,1))
            
            big_helper = ft.partial(giant_filter, size_k = size_k, last_cell_dict = candidate_cell_dict, min_shared_cells = min_shared_cells)
            
            while total_done < ell - 1:
                next_candidates = it.chain(next_candidates, filter(lambda x: x is not None, p.map(big_helper, it.islice(valid_candidates_iter, pow(2048,2)))))
                #print('iter chunk: ' + str(total_done) + ' of ' + str(ell-1))
                total_done += 1
                
            next_candidates = it.chain(next_candidates, filter(lambda x: x is not None, p.map(big_helper, valid_candidates_iter)))
            total_done += 1
            
            #print('generated next candidates!')
            
            big_cander = ft.partial(giant_candidater, last_cell_dict=candidate_cell_dict, total_cells=total_cells)
            
            chainer = iter([])
            
            while(total_done > 0):
                chainer = it.chain(chainer, iter(p.map(big_cander, next_candidates)))
                #print('remaining chunk: ' + str(total_done) + ' of ' + str(ell))
                total_done -=  1
            
            try:
                candidates, counts, expected_counts, candidate_cell_dict = zip(*chainer)
            except:
                print('Ran out of candidates here!')
                break
            candidate_cell_dict = {x:candidate_cell_dict[x] for x[0] in candidate_cell_dict}
            
            p.close()
        else:
            ### this is the easy logic section in case we have fewer combinations. Much easier to read
            next_candidates = set(get_next_valid_candidates(candidates, n, size_k))
            print('Generated next valid candidates! Size: ' + str(size_k))
            print(len(next_candidates))
            print(next(iter(next_candidates)))
            candidates, counts, expected_counts, candidate_cell_dict = generate_new_candidates(next_candidates, candidate_cell_dict, min_shared_cells, total_cells)
        
        # calculate the new length of cells for the next iteration
        n = len(candidate_cell_dict)
        print('Evaluated '+ str(n) +' valid candidates!')
            
        # store our candidates at each stage. This allows us to reduce our RAM usage
        if(n > 0):
            pickle_candidates(candidates, counts, expected_counts, candidate_cell_dict, size_k)

In [4]:
dat = pd.read_csv('./cord_blood_kinases.csv', sep=',', header=0, index_col=0)

  mask |= (ar1 == a)


In [5]:
len(set(dat[dat['symbol']=='NRBP1']['barcode'].values) & set(dat[dat['symbol']=='LYN']['barcode'].values))

3455

    candidate_cell_dict = tidy_df.groupby('symbol')
    counts = candidate_cell_dict.count()['barcode'].to_dict()
    candidate_cell_dict = candidate_cell_dict['barcode'].apply(frozenset)

In [6]:
dat.groupby('symbol')['barcode'].count().to_dict()['NRBP1']

32942

In [7]:
dat.head()

Unnamed: 0,barcode,symbol
1,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,STK4
2,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,RIOK3
3,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,PDK2
4,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,MAPK7
5,MantonCB1_HiSeq_1-AAACCTGCACAGACAG-1,CDK10


In [8]:
print(dat.shape)

(5531681, 2)


In [None]:
s = time.time()
gather_gene_sets(dat, min_percent_cells=.01)
t = time.time()

print(str(t-s) + ' seconds')

303 first candidates!
Generated next valid candidates! Size: 2
45753
(frozenset({'HCK'}), frozenset({'RIPK1'}))
50.04134488105774 seconds to count candidates
0.012686014175415039 seconds to compress
47.803678035736084 seconds to count expected values
Evaluated 5182 valid candidates!
created iter of size 26847942
