In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import glob
import io

import boto3
import pandas
from scipy.spatial import cKDTree
from esutil.htm import HTM
from astropy.io import fits
from astropy.table import Table

plt.rc('figure', dpi=120)

In [None]:
# Don't run either if we're on AWS; sc is pre-existing.
if False:
    from pyspark import SparkContext, SparkConf
    from pyspark.sql import SparkSession
    from pyspark.sql import Row

    sc = SparkContext('local[*]')
    sc.setLogLevel("WARN")
else:
    from pysparkling import Context
    sc = Context(max_retries=1)


In [4]:
ra_field, dec_field = "ALPHAWIN_J2000", "DELTAWIN_J2000"

Pandas version
-----------

In [101]:

# This is a critical performance setting, since otherwise
# pandas will manually call the garbage collector way too much
# and that will dominate the runtime.
pandas.set_option('mode.chained_assignment', None)

def readPTFFile_pandas(filename):
    table = Table.read(filename)
    table.remove_columns(('MAG_APER', 'MAGERR_APER', 'FLUX_APER', 'FLUXERR_APER', 'FLUX_RADIUS'))
    df = table.to_pandas()
    return df

def readPTFFile_pandas_s3(key):
    s3 = boto3.resource('s3')
    ptf_bucket = s3.Bucket("palomar-transient-factory")

    with io.BytesIO() as f:
        ret = ptf_bucket.download_fileobj(key, f)
        f.seek(0)
        ptf_table = Table.read(f, format="fits")
    ptf_table.remove_columns(('MAG_APER', 'MAGERR_APER', 'FLUX_APER', 'FLUXERR_APER', 'FLUX_RADIUS'))
    df = ptf_table.to_pandas()
    return df

def split_table_by_htmid(table):
    # Return a set of tuples [ (htm_id, Table), ... ]
    htm_obj = HTM(depth=6)
    ra_field, dec_field = "ALPHAWIN_J2000", "DELTAWIN_J2000"
    
    ra = table[ra_field]
    dec = table[dec_field]
    htm_id = htm_obj.lookup_id(ra, dec)
    table['htm_id'] = htm_id
    htm_groups = table.groupby(htm_id)
    return htm_groups

def match_sources_pandas(input_tuple):
    htm_id, table = input_tuple
    
    ra_field, dec_field = "ALPHAWIN_J2000", "DELTAWIN_J2000"    
    ra = table[ra_field]
    dec = table[dec_field]

    tree = cKDTree(np.stack((np.array(ra), np.array(dec)), axis=1))
    print("Len ", len(ra))

    candidate_groups = []
    already_matched_ids = set()
    matched_obj_counts = []
    # First pass, just to get groupings
    for this_id in range(len(ra)):
        if this_id in already_matched_ids:
            continue
        dists, idx = tree.query( (ra.iloc[this_id], dec.iloc[this_id]), k=15, distance_upper_bound=4/3600.0)
        sel, = np.where((dists < 3/3600.0) & (idx != this_id))
        
        valid_match_ids = set(idx[sel]) - already_matched_ids
        matched_obj_counts.append(len(valid_match_ids))
        
        if len(valid_match_ids) == 0:
            continue
        candidate_table = table.iloc[list(valid_match_ids)]
        mean_ra = np.mean(candidate_table[ra_field])
        mean_dec = np.mean(candidate_table[dec_field])   
        candidate_groups.append((mean_ra, mean_dec))
        already_matched_ids.update(valid_match_ids)

        
    if(len(candidate_groups) == 0):
        table['obj_id'] = 0
        return table
        
    # Now we match all the sources to their closest object, if it's within bounds
    reverse_tree = cKDTree(np.stack(candidate_groups))
    dists, idx = reverse_tree.query(np.stack((np.array(ra), np.array(dec)), axis=1), distance_upper_bound=3/3600.0)
    
    table['obj_id'] = htm_id*100000 + idx

    return table

zeroValue = None
def seqFunc(a, b):
    if a is None:
        return b
    if b is None:
        return a
    return a.append(b)

In [8]:
s3 = boto3.resource('s3')
ptf_bucket = s3.Bucket("palomar-transient-factory")


In [106]:
%%time

bucket_keys = [x.key for x in ptf_bucket.objects.filter(Prefix="input_catalogs/").limit(300)]
catalog_keys = filter(lambda x: x.endswith("ctlg"), bucket_keys)

split_records = sc.parallelize(catalog_keys).map(readPTFFile_pandas_s3).flatMap(split_table_by_htmid)

grouped_records = split_records.aggregateByKey(zeroValue, seqFunc, seqFunc)
print("Number of HTM cells: ", grouped_records.count())

matched_records = grouped_records.map(match_sources_pandas).collect()

('Number of HTM cells: ', 32)
CPU times: user 500 ms, sys: 488 ms, total: 988 ms
Wall time: 1min 21s


In [110]:
len(matched_records[6]['obj_id']), len(np.unique(matched_records[6]['obj_id']))

(23092, 3826)

In [111]:
matched_records[6]

Unnamed: 0,NUMBER,FLAGS,XWIN_IMAGE,YWIN_IMAGE,X_WORLD,Y_WORLD,XPEAK_IMAGE,YPEAK_IMAGE,ERRTHETAWIN_IMAGE,DELTAWIN_J2000,...,ERRX2_IMAGE,ERRY2_IMAGE,ERRXY_IMAGE,AWIN_IMAGE,BWIN_IMAGE,FLUX_PETRO,FLUXERR_PETRO,ZEROPOINT,htm_id,obj_id
2,3,0,1951.583884,9.330748,83.444823,1.899086,1952,9,65.479416,1.899076,...,0.001854,0.001934,1.053254e-04,0.719189,0.689573,23444.683594,1149.866577,26.900555,61477,6147700000
5,6,24,2046.630441,9.457185,83.471519,1.899223,2047,9,-62.619141,1.899113,...,0.003962,0.009179,-2.112649e-03,1.215652,0.964346,13807.681641,1463.676147,26.899826,61477,6147700001
6,7,0,1913.531346,9.218420,83.434187,1.899098,1914,9,55.347759,1.899077,...,0.002433,0.002101,-2.012397e-05,0.744372,0.667771,19411.265625,935.117859,26.900846,61477,6147700002
11,12,17,1980.209315,3.640525,83.452894,1.900588,1980,4,76.863327,1.900697,...,0.009517,0.010738,-6.978014e-04,0.856739,0.762328,8130.187500,1037.086304,26.900425,61477,6147700003
17,18,17,1987.183937,2.767323,83.454824,1.900784,1987,3,2.800322,1.900948,...,0.023895,0.010336,5.105952e-03,0.914331,0.828392,5491.603516,1071.114136,26.900385,61477,6147700004
19,20,24,1935.033239,1.484919,83.440693,1.901355,1939,1,-0.058739,1.901267,...,0.290314,0.002341,-6.964515e-03,2.828972,0.872522,8589.767578,983.322815,26.900810,61477,6147703830
24,25,0,1942.573842,22.292838,83.442324,1.895440,1943,22,71.373009,1.895428,...,0.001773,0.002055,1.742621e-04,0.751059,0.697347,22545.015625,966.854980,26.900410,61477,6147700005
29,30,0,1992.330951,20.462941,83.456246,1.896006,1992,20,52.892029,1.895981,...,0.030341,0.031866,2.123634e-02,1.400311,0.816164,6344.582031,1848.789185,26.900068,61477,6147700006
33,34,0,1985.115742,24.674069,83.454245,1.894793,1985,25,-86.729156,1.894793,...,0.006083,0.011026,-1.441188e-03,0.696001,0.617740,4856.511719,837.035950,26.900055,61477,6147700007
46,47,0,1936.405636,34.442798,83.440594,1.892027,1936,34,58.102726,1.892010,...,0.010604,0.010595,1.688994e-03,0.746125,0.619016,4689.101562,1291.679688,26.900255,61477,6147700008


In [51]:
len(matched_records[4]['obj_id']), len(np.unique(matched_records[4]['obj_id']))

(3252, 394)