In [1]:
import sys
import os
import gc
import itertools as it
import tempfile
from collections import Counter, namedtuple
import datetime as dt
import multiprocessing as mp

import numpy as np
import pandas as pd

from joblib import Parallel, delayed, cpu_count, load, dump, Memory

%matplotlib inline
import matplotlib.pyplot as plt

In [2]:
%%capture
# corral startup
sys.path.insert(1, "/home/jbcabral/carpyncho3")
os.environ.setdefault("CORRAL_SETTINGS_MODULE", "carpyncho.settings")

from corral import core
core.setup_environment()

from corral import db
from carpyncho.models import LightCurves, PawprintStack, PawprintStackXTile

In [3]:
%%capture capt
# Fetch, for tile "b396", the LightCurves record's observation counters and
# observations, plus a mapping pawprint-stack id -> MJD for every
# PawprintStackXTile row linked to that same tile.
# `capt` only receives the cell's captured output; nothing visible here reads it.

with db.session_scope() as session:
    # NOTE(review): presumably SQLAlchemy's Query.one(), which raises unless
    # exactly one row matches — confirm against corral's session type.
    lc = session.query(LightCurves).filter(LightCurves.tile.has(name="b396")).one()
    # indexed later by its "id" and "cnt" fields to pick well-observed sources
    obs_counter = lc.obs_counter
    # indexed later by its "bm_src_id" and "pwp_id" fields
    observations = lc.observations
    pxts = session.query(PawprintStackXTile).filter(PawprintStackXTile.tile == lc.tile).all()
    # NOTE(review): if `pawprint_stack` is lazily loaded this issues one query
    # per pxt; a joined/eager load may be much faster — verify.
    mjds = {pxt.pawprint_stack.id: pxt.pawprint_stack.mjd for pxt in pxts}

In [4]:
# Randomly pick N_SOURCES distinct sources among those with at least MIN_OBS
# observations, and keep only their (source, pawprint) observation rows.
# A fixed seed makes the sample — and every result downstream — reproducible.
MIN_OBS = 87
N_SOURCES = 10000
SEED = 42

rng = np.random.RandomState(SEED)
obs_min = rng.choice(
    obs_counter[obs_counter["cnt"] >= MIN_OBS]["id"], N_SOURCES, replace=False)

df = pd.DataFrame(observations[['bm_src_id', u'pwp_id']])
# .copy(): the next cell adds a column to `df`; working on an independent
# frame (not a filtered slice) avoids SettingWithCopy ambiguity.
df = df[df.bm_src_id.isin(obs_min)].copy()

In [5]:
df["mjd"] = df.pwp_id.apply(lambda pwp_id: mjds[pwp_id])

%time groups = df.groupby("bm_src_id")

mean_mjds = dict(groups.mjd.mean())

CPU times: user 1 ms, sys: 0 ns, total: 1 ms
Wall time: 735 µs


In [6]:
# Dump the groupby to disk and reload it with joblib's mmap_mode, so the numpy
# arrays inside come back memory-mapped (presumably to let the worker
# processes share them rather than hold private copies).
temp_folder = "./cache"
if not os.path.isdir(temp_folder):
    # dump() opens the target file directly and fails if the folder is missing
    os.makedirs(temp_folder)
filename = os.path.join(temp_folder, 'groups.mmap')
dump(groups, filename)
groups = load(filename, mmap_mode='r+')

In [7]:
%time combinations = tuple(it.combinations(obs_min, 2))

CPU times: user 51.8 s, sys: 2.52 s, total: 54.3 s
Wall time: 54.3 s


In [8]:
%time chunks = np.array_split(combinations, 100)

CPU times: user 41.2 s, sys: 378 ms, total: 41.6 s
Wall time: 41.6 s


In [None]:
Best = namedtuple("Best", ["src_1", "src_2", "n_1", "n_2", "int_12", "mmjd_1", "mmjd_2", "mmjd_diff"])
fields = list(Best._fields)


def _greedy_select(df, ascending):
    """Greedily keep source pairs that share no source with a better pair.

    Rows are ranked by ``(mmjd_diff, int_12)`` in the given direction and
    walked in that order; a row is kept only if neither its ``src_1`` nor its
    ``src_2`` already appears in a kept row.  This equals repeatedly taking
    the best remaining row and discarding every row sharing a source with it,
    but sorts only once.

    Parameters
    ----------
    df : pandas.DataFrame
        Candidate pairs with the columns in ``fields``.  May be empty,
        including a column-less ``pd.DataFrame([])``.
    ascending : bool
        Sort direction for both ranking keys.

    Returns
    -------
    pandas.DataFrame
        Kept rows with columns ordered as ``fields`` (empty if ``df`` is).
    """
    if not len(df):
        # pd.DataFrame([]) has no columns, so sort_values/[fields] would raise.
        return pd.DataFrame(columns=fields)
    ranked = df.sort_values(["mmjd_diff", "int_12"], ascending=ascending)
    used = set()
    kept = []
    for row in ranked.to_dict("records"):
        # ids may arrive as floats (e.g. after concat); compare as ints.
        pair = (int(row["src_1"]), int(row["src_2"]))
        if used.intersection(pair):
            continue
        used.update(pair)
        kept.append(row)
    return pd.DataFrame(kept, columns=fields)


def select_nbg(df):
    """Pick disjoint pairs preferring the smallest ``mmjd_diff`` (then ``int_12``)."""
    return _greedy_select(df, ascending=True)


def select_no_nbg(df):
    """Pick disjoint pairs preferring the largest ``mmjd_diff`` (then ``int_12``)."""
    return _greedy_select(df, ascending=False)

class GetNBG(mp.Process):
    """Worker process that scores a slice of source pairs.

    For each ``(k1, k2)`` in ``chunk`` it counts the pawprint ids the two
    sources share (``int_12``), takes the absolute difference of their mean
    MJDs and builds a ``Best`` record.  Records with ``int_12 >= 50`` go to
    the "nbgs" list, the rest to "no_nbgs"; each list is reduced with
    ``select_nbg`` / ``select_no_nbg`` and the resulting pair of DataFrames
    is put on ``self.queue``, from which the parent reads it via ``result()``.

    Parameters
    ----------
    idx, total : int
        This worker's index and the number of workers (log messages only).
    chunk : sequence of (k1, k2)
        Source-id pairs to compare.
    groups : object with ``get_group(k)``
        Observations grouped by source id; each group must have a ``pwp_id``
        column (in this notebook, the reloaded ``df.groupby("bm_src_id")``).
    means : mapping
        Source id -> mean MJD.
    """

    def __init__(self, idx, total, chunk, groups, means):
        super(GetNBG, self).__init__()
        # bookkeeping shown in the progress messages
        self.idx, self.total, self.size = idx, total, len(chunk)
        # the work itself and the data needed to do it
        self.chunk, self.groups, self.means = chunk, groups, means
        # per-process memo: source id -> frozenset of pawprint ids
        self._cache = {}
        # carries the (nbgs, no_nbgs) result back to the parent
        self.queue = mp.Queue()

    def get_group(self, k):
        """Return the frozenset of ``pwp_id`` values of source ``k`` (memoised)."""
        pwp_ids = self._cache.get(k)
        if pwp_ids is None:
            pwp_ids = frozenset(self.groups.get_group(k)["pwp_id"].values)
            self._cache[k] = pwp_ids
        return pwp_ids

    def run(self):
        """Score every pair of ``self.chunk`` and enqueue ``(nbgs, no_nbgs)``."""
        print("[{}] Starting {}/{} with {} sources".format(
            dt.datetime.now(), self.idx, self.total, self.size))

        # True -> the pair shares >= 50 pawprints ("nbgs"), False -> the rest
        by_overlap = {True: [], False: []}
        for k1, k2 in self.chunk:
            pwps_1 = self.get_group(k1)
            pwps_2 = self.get_group(k2)
            shared = len(pwps_1 & pwps_2)
            mean_1 = self.means[k1]
            mean_2 = self.means[k2]
            by_overlap[shared >= 50].append(Best(
                src_1=k1, src_2=k2,
                n_1=len(pwps_1), n_2=len(pwps_2), int_12=shared,
                mmjd_1=mean_1, mmjd_2=mean_2,
                mmjd_diff=np.abs(mean_1 - mean_2)))

        self.queue.put((
            select_nbg(pd.DataFrame(by_overlap[True])),
            select_no_nbg(pd.DataFrame(by_overlap[False]))))

        print("[{}] Done {}/{}".format(dt.datetime.now(), self.idx, self.total))

    def result(self):
        """Block until this worker's ``(nbgs, no_nbgs)`` is available and return it."""
        return self.queue.get()
    
# Process every chunk of source pairs on `total` worker processes and fold
# the per-worker selections into the global `nbgs` / `no_nbgs` frames.
total = cpu_count()
nbgs, no_nbgs = None, None


def _merge(current, new, select):
    """Fold a new partial selection into the running one.

    ``current is None`` means nothing has been collected yet, so ``new`` is
    taken as is; otherwise both are concatenated and re-reduced by ``select``.
    """
    if current is None:
        return new
    return select(pd.concat([current, new]))


for cidx, chunk in enumerate(chunks):
    print("Starting chunk {}/{}".format(cidx, len(chunks)))
    procs = []
    for idx, for_cpu in enumerate(np.array_split(chunk, total)):
        proc = GetNBG(idx=idx, total=total, chunk=for_cpu, groups=groups, means=mean_mjds)
        proc.start()
        procs.append(proc)
    del proc, for_cpu
    gc.collect()

    chunk_nbgs, chunk_no_nbgs = None, None
    for proc in procs:
        # Read the result BEFORE joining: a child that has put a large object
        # on an mp.Queue does not exit until it is consumed, so join() first
        # can deadlock.  NOTE: result() has no timeout — a worker that dies
        # before putting its result leaves this call blocked.
        p_nbgs, p_no_nbgs = proc.result()
        proc.join()
        chunk_nbgs = _merge(chunk_nbgs, p_nbgs, select_nbg)
        chunk_no_nbgs = _merge(chunk_no_nbgs, p_no_nbgs, select_no_nbg)

    nbgs = _merge(nbgs, chunk_nbgs, select_nbg)
    no_nbgs = _merge(no_nbgs, chunk_no_nbgs, select_no_nbg)

    del procs, proc, p_nbgs, p_no_nbgs, chunk_nbgs, chunk_no_nbgs
    gc.collect()

    print("=" * 50)

Starting chunk 0/100
[2018-09-27 11:12:42.752507] Starting 0/48 with 10416 sources
[2018-09-27 11:12:42.752567] Starting 1/48 with 10416 sources
[2018-09-27 11:12:42.863309] Starting 2/48 with 10416 sources
[2018-09-27 11:12:43.002648] Starting 3/48 with 10416 sources
[2018-09-27 11:12:43.145290] Starting 4/48 with 10416 sources
[2018-09-27 11:12:43.302078] Starting 5/48 with 10416 sources
[2018-09-27 11:12:43.458108] Starting 6/48 with 10416 sources
[2018-09-27 11:12:43.620640] Starting 7/48 with 10416 sources
[2018-09-27 11:12:43.801045] Starting 8/48 with 10416 sources
[2018-09-27 11:12:44.019124] Starting 9/48 with 10416 sources
[2018-09-27 11:12:44.227534] Starting 10/48 with 10416 sources
[2018-09-27 11:12:44.447345] Starting 11/48 with 10416 sources
[2018-09-27 11:12:44.667506] Starting 12/48 with 10416 sources
[2018-09-27 11:12:44.905543] Starting 13/48 with 10416 sources
[2018-09-27 11:12:45.137440] Starting 14/48 with 10416 sources
[2018-09-27 11:12:45.373605] Starting 15/48 

In [None]:
# (rows, columns) of the selected nbg pairs
(len(nbgs.index), len(nbgs.columns))

In [None]:
# (rows, columns) of the selected no-nbg pairs
(len(no_nbgs.index), len(no_nbgs.columns))

In [None]:
# persist the nbg selection for later notebooks
pd.to_pickle(nbgs, "data/nbgs2.pkl")

In [None]:
# persist the no-nbg selection for later notebooks
pd.to_pickle(no_nbgs, "data/no_nbgs2.pkl")