In [1]:
import sys
import os
import gc
import itertools as it
import tempfile
from collections import Counter, namedtuple
import datetime as dt
import multiprocessing as mp

import numpy as np
import pandas as pd

from joblib import Parallel, delayed, cpu_count, load, dump, Memory

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
%%capture
# corral startup
# Make the carpyncho checkout importable and tell corral which settings
# module to use, then initialise corral before importing the DB layer.
# NOTE(review): absolute, user-specific path -- the notebook only runs on
# a machine that has this checkout at this location.
sys.path.insert(1, "/home/jbcabral/carpyncho3")
# setdefault: an already exported CORRAL_SETTINGS_MODULE is left untouched
os.environ.setdefault("CORRAL_SETTINGS_MODULE", "carpyncho.settings")

from corral import core
# presumably must run before the models below are imported -- confirm
core.setup_environment()

from corral import db
from carpyncho.models import LightCurves, PawprintStack, PawprintStackXTile

In [3]:
%%capture capt

with db.session_scope() as session:
    lc = session.query(LightCurves).filter(LightCurves.tile.has(name="b396")).one()
    obs_counter = lc.obs_counter
    observations = lc.observations
    pxts = session.query(PawprintStackXTile).filter(PawprintStackXTile.tile == lc.tile).all()
    mjds = {pxt.pawprint_stack.id: pxt.pawprint_stack.mjd for pxt in pxts}

In [4]:
# Sampling parameters (previously inline magic numbers).
MIN_OBS = 87       # minimum value of the "cnt" field for a source to be eligible
N_SOURCES = 10000  # how many eligible source ids to draw
SEED = 42          # fixed so that Restart & Run All draws the same sample

# Draw N_SOURCES distinct ids (replace=False) from a dedicated, seeded
# generator: the sample is reproducible and the global numpy RNG state
# is left untouched.
rng = np.random.RandomState(SEED)
obs_min = rng.choice(
    obs_counter[obs_counter["cnt"] >= MIN_OBS]["id"],
    size=N_SOURCES, replace=False)

# Keep only the observations of the sampled sources.  The explicit copy
# makes ``df`` an independent frame, so assigning new columns to it later
# does not raise pandas' SettingWithCopyWarning.
df = pd.DataFrame(observations[['bm_src_id', u'pwp_id']])
df = df[df.bm_src_id.isin(obs_min)].copy()

In [5]:
df["mjd"] = df.pwp_id.apply(lambda pwp_id: mjds[pwp_id])

%time groups = df.groupby("bm_src_id")

mean_mjds = dict(groups.mjd.mean())

CPU times: user 1e+03 µs, sys: 0 ns, total: 1e+03 µs
Wall time: 888 µs


In [6]:
# Persist the grouped frame with joblib and load it straight back,
# asking for memory-mapped access to the stored data.
temp_folder = "./cache"
# dump() writes to the given path but does not create missing parent
# directories, so make sure the cache folder exists on a fresh checkout.
if not os.path.isdir(temp_folder):
    os.makedirs(temp_folder)
filename = os.path.join(temp_folder, 'groups.mmap')
dump(groups, filename)
# NOTE(review): mmap_mode applies to numpy arrays stored in the file --
# confirm it has the intended effect for a pandas GroupBy object.
groups = load(filename, mmap_mode='r+')

In [12]:
# Materialise every unordered pair of sampled source ids.
# NOTE(review): with the 10000 ids drawn into obs_min this is
# 10000 * 9999 / 2 = 49,995,000 tuples held in memory at once.
%time combinations = tuple(it.combinations(obs_min, 2))

CPU times: user 1min 50s, sys: 4.32 s, total: 1min 54s
Wall time: 1min 54s


In [65]:
# Split the sequence of pairs into 10 consecutive chunks
# (array_split allows the chunks to differ in length).
%time chunks = np.array_split(combinations, 10)

CPU times: user 42.9 s, sys: 527 ms, total: 43.4 s
Wall time: 47.7 s


In [66]:
# Record describing one pair of sources:
#   src_1, src_2    ids of the two sources
#   n_1, n_2        number of distinct pawprint stacks of each source
#   int_12          number of pawprint stacks the two sources share
#   mmjd_1, mmjd_2  mean MJD of each source
#   mmjd_diff       absolute difference between the two mean MJDs
Best = namedtuple(
    "Best",
    "src_1 src_2 n_1 n_2 int_12 mmjd_1 mmjd_2 mmjd_diff")

class GetNBG(mp.Process):
    """Worker that scans pairs of sources and keeps the best pair per source.

    For every ``(k1, k2)`` pair in ``chunk`` a ``Best`` record is built from
    the sets of pawprint ids of both sources and their mean MJDs.  Pairs
    sharing at least ``MIN_NBG_INTERSECTION`` pawprint stacks compete for
    the "nbg" buffer, the remaining pairs for the "no nbg" buffer.  When the
    scan finishes, the two sets of surviving records are put on
    ``self.queue``.

    ``start()`` executes ``run()`` in a child process; calling ``run()``
    directly executes the scan in the calling process.

    Parameters
    ----------
    idx : int
        Index of this worker (only used in the progress messages).
    total : int
        Total number of workers (only used in the progress messages).
    chunk : sequence of pairs
        The ``(source id, source id)`` pairs to evaluate.
    groups : object with ``get_group(key)``
        Grouped observations; ``get_group(k)["pwp_id"].values`` must give
        the pawprint ids of source ``k``.
    means : mapping
        Source id -> mean MJD of the source.
    """

    # pairs sharing at least this many pawprint stacks are "nbg"
    # candidates, all the others are "no nbg" candidates
    MIN_NBG_INTERSECTION = 50

    def __init__(self, idx, total, chunk, groups, means):
        super(GetNBG, self).__init__()
        self.idx = idx
        self.total = total
        self.chunk = chunk
        self.groups = groups
        self.means = means
        self.size = len(chunk)
        # source id -> frozenset of its pawprint ids (filled lazily)
        self._cache = {}
        # channel used by run() to hand the result back to the creator
        self.queue = mp.Queue()

    def get_group(self, k):
        """Return the frozenset of pawprint ids of source ``k`` (memoized)."""
        if k not in self._cache:
            self._cache[k] = frozenset(
                self.groups.get_group(k)["pwp_id"].values)
        return self._cache[k]

    def is_best_nbg_candidate(self, candidate, nbgs):
        """Tell whether ``candidate`` beats what ``nbgs`` holds for its sources.

        A larger intersection wins; with the same intersection the smaller
        ``mmjd_diff`` wins.  Returns ``True`` only when the candidate beats
        the stored record of *both* of its sources (a source without a
        stored record never rejects the candidate).
        """
        for k in (candidate.src_1, candidate.src_2):
            # nothing stored for this source: nothing to beat
            if k not in nbgs:
                continue
            other = nbgs[k]

            # the very same record, or fewer shared stacks: rejected.
            # The comparison is strict on purpose: with ``<=`` a tie was
            # rejected right here and the tie-break below could never run.
            if candidate == other or candidate.int_12 < other.int_12:
                return False

            # same intersection: only a strictly smaller dispersion wins
            if (candidate.int_12 == other.int_12 and
                    candidate.mmjd_diff >= other.mmjd_diff):
                return False

        # here the candidate is ok
        return True

    def is_best_no_nbg_candidate(self, candidate, no_nbgs):
        """Tell whether ``candidate`` beats what ``no_nbgs`` holds for its sources.

        Mirror image of ``is_best_nbg_candidate``: a smaller intersection
        wins; with the same intersection the larger ``mmjd_diff`` wins.
        """
        for k in (candidate.src_1, candidate.src_2):
            # nothing stored for this source: nothing to beat
            if k not in no_nbgs:
                continue
            other = no_nbgs[k]

            # the very same record, or more shared stacks: rejected
            if candidate == other or candidate.int_12 > other.int_12:
                return False

            # same intersection: only a strictly larger dispersion wins
            if (candidate.int_12 == other.int_12 and
                    candidate.mmjd_diff <= other.mmjd_diff):
                return False

        # here the candidate is ok
        return True

    def run(self):
        """Scan ``self.chunk`` and put ``(nbgs, no_nbgs)`` on ``self.queue``.

        Both elements are frozensets of ``Best`` records.  A progress
        message is printed when the scan starts and when it ends.
        """
        start = dt.datetime.now()
        print("[{}] Starting {}/{} with {} sources".format(
            start, self.idx, self.total, self.size))

        threshold = self.MIN_NBG_INTERSECTION
        nbgs, no_nbgs = {}, {}
        for k1, k2 in self.chunk:
            # extract the two groups to compare
            # and create the candidate object
            g1, g2 = self.get_group(k1), self.get_group(k2)
            int_12 = len(g1.intersection(g2))
            mmjd_1, mmjd_2 = self.means[k1], self.means[k2]
            candidate = Best(
                src_1=k1, src_2=k2,
                n_1=len(g1), n_2=len(g2), int_12=int_12,
                mmjd_1=mmjd_1, mmjd_2=mmjd_2,
                mmjd_diff=np.abs(mmjd_1 - mmjd_2))

            # a winning candidate evicts the records it replaces and is
            # then stored under both of its sources
            if (int_12 >= threshold and
                    self.is_best_nbg_candidate(candidate, nbgs)):
                self.clean_buff(candidate, nbgs)
                nbgs[k1] = candidate
                nbgs[k2] = candidate
            elif (int_12 < threshold and
                    self.is_best_no_nbg_candidate(candidate, no_nbgs)):
                self.clean_buff(candidate, no_nbgs)
                no_nbgs[k1] = candidate
                no_nbgs[k2] = candidate

        # every record is stored under two keys: the sets deduplicate them
        nbgs = frozenset(nbgs.values())
        no_nbgs = frozenset(no_nbgs.values())
        self.queue.put((nbgs, no_nbgs))

        end = dt.datetime.now()
        print("[{}] Done {}/{}".format(end, self.idx, self.total))

    def clean_buff(self, candidate, buff):
        """Drop from ``buff`` every entry superseded by ``candidate``.

        Removes the entries of both candidate sources and, for each record
        found there, also the entries of that record's two sources, so no
        half of a replaced pair is left behind.  ``buff`` is modified in
        place.
        """
        to_remove = set([candidate.src_1, candidate.src_2])
        # iterate over a snapshot: the set grows inside the loop
        for k in tuple(to_remove):
            if k in buff:
                other = buff[k]
                to_remove.update((other.src_1, other.src_2))
        for k in to_remove:
            buff.pop(k, None)

    def result(self):
        """Return the ``(nbgs, no_nbgs)`` pair produced by ``run()``.

        Blocks until the pair is available and consumes it from the queue.
        When the worker was launched with ``start()``, call this before
        ``join()``: a child that has put data on a queue may not exit
        until that data has been read.
        """
        # was ``queue.get()``: an undefined global name (NameError)
        return self.queue.get()
    
# Driver for the pair scan.
# NOTE: in its current state this is a single-process smoke test, not the
# parallel run: only the first of the ``total`` slices of the first chunk
# is processed, and both loops are left through ``break`` right after it.
# ``proc`` and ``for_cpu`` stay defined afterwards and are inspected by
# the following cells.
total = cpu_count()
# NOTE(review): these two lists are never filled by the active code below
results_nbg, result_zeros = [], []
for cidx, chunk in enumerate(chunks):
    print("Starting chunk {}".format(cidx))
    procs = []
    # one slice of the chunk per available CPU
    for idx, for_cpu in enumerate(np.array_split(chunk, total)):
        proc = GetNBG(idx=idx, total=total, chunk=for_cpu, groups=groups, means=mean_mjds)
        # run() is called directly, so the scan executes synchronously in
        # the kernel process; no child process is started
        proc.run()
        # stop after the first slice
        break
# Disabled draft of the parallel version.
# NOTE(review): it refers to ``NbgCalc`` and ``results``, names that are
# not defined in the visible part of the notebook -- confirm before
# re-enabling.
#         proc = NbgCalc(idx=idx, combs=for_cpu, total=total, groups=groups)
#         proc.start()
#         procs.append(proc)
#     for proc in procs:
#         proc.join()
#         presult = proc.result()
#         results.append(presult)
#     del procs
    # stop after the first chunk; the print below is never reached
    break
    print("=" * 50)
    

Starting chunk 0
[2018-09-25 12:34:29.711058] Starting 0/48 with 104157 sources
[2018-09-25 12:37:22.006357] Done 0/48


In [67]:
# Take the (nbgs, no_nbgs) pair that run() put on the worker's queue and
# show how many records each of the two sets holds.
# NOTE: get() consumes the item, so re-running this cell without
# re-running the worker blocks waiting for a new one.
r = proc.queue.get()
# list(...) so the sizes are displayed on Python 3 too, where a bare
# map() is a lazy iterator
list(map(len, r))

[11, 7]

In [64]:
# Number of occurrences of every source id in the slice handed to the
# worker, most frequent first.
Counter(np.ravel(for_cpu)).most_common()

[(33960000878767, 9999),
 (33960000326804, 418),
 (33960000782583, 2),
 (33960000651553, 2),
 (33960000717175, 2),
 (33960001012113, 2),
 (33960000487856, 2),
 (33960000324026, 2),
 (33960000848324, 2),
 (33960000356809, 2),
 (33960000356861, 2),
 (33960000193026, 2),
 (33960000815708, 2),
 (33960000979703, 2),
 (33960000946972, 2),
 (33960000914364, 2),
 (33960000717840, 2),
 (33960000423184, 2),
 (33960000587037, 2),
 (33960000554364, 2),
 (33960000226694, 2),
 (33960000301986, 2),
 (33960000390621, 2),
 (33960000325281, 2),
 (33960000849605, 2),
 (33960000030475, 2),
 (33960000809980, 2),
 (33960000686210, 2),
 (33960000302107, 2),
 (33960000096426, 2),
 (33960000325856, 2),
 (33960000751994, 2),
 (33960000194971, 2),
 (33960000555459, 2),
 (33960000752079, 2),
 (33960000915958, 2),
 (33960000719418, 2),
 (33960000064063, 2),
 (33960000326222, 2),
 (33960000228067, 2),
 (33960000392005, 2),
 (33960000097108, 2),
 (33960000752596, 2),
 (33960000064488, 2),
 (33960000097297, 2),
 (339