In [1]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import glob
import io

import boto3
import pandas
from scipy.spatial import cKDTree
from esutil.htm import HTM
from astropy.io import fits
from astropy.table import Table

from pyspark.sql.functions import udf
from pyspark.sql.types import LongType
from pyspark.sql import Row

plt.rc('figure', dpi=120)

In [2]:
# Don't run either if we're on AWS; sc is pre-existing.
if False:
    from pyspark import SparkContext, SparkConf
    from pyspark.sql import SparkSession
    from pyspark.sql import Row

    sc = SparkContext('local[*]')
    sc.setLogLevel("WARN")
else:
    from pysparkling import Context
    sc = Context(max_retries=1)


ImportError: No module named pysparkling

Dataframe Version
=======

In [3]:
s3 = boto3.resource('s3')
ptf_bucket = s3.Bucket("palomar-transient-factory")


In [4]:
bucket_keys = [x.key for x in ptf_bucket.objects.filter(Prefix="input_parquet2/")]
catalog_keys = filter(lambda x: x.endswith("parquet"), bucket_keys)


In [5]:
%%time
catalog2 = spark.read.parquet(*["s3://palomar-transient-factory/" + key for key in catalog_keys])


CPU times: user 5.92 s, sys: 420 ms, total: 6.34 s
Wall time: 12min 5s


In [None]:
%%time

def htm_udf_func(ra, dec):
    htm_obj = HTM(depth=8)
    return htm_obj.lookup_id(ra, dec).item()
    
htm_udf = udf(htm_udf_func, LongType())

def match_sources_sparksql(input_tuple):
    htm_id, rows = input_tuple
    
    ra_field, dec_field = "ALPHAWIN_J2000", "DELTAWIN_J2000"
    ra = np.array([x[ra_field] for x in rows])
    dec = np.array([x[dec_field] for x in rows])

    tree = cKDTree(np.stack((ra, dec), axis=1))

    candidate_groups = []
    already_matched_ids = set()
    matched_obj_counts = []
    # First pass, just to get groupings
    for this_id in range(len(ra)):
        if this_id in already_matched_ids:
            continue
        dists, idx = tree.query( (ra[this_id], dec[this_id]), k=15, distance_upper_bound=4/3600.0)
        sel, = np.where((dists < 3/3600.0) & (idx != this_id))
        
        valid_match_ids = set(idx[sel]) - already_matched_ids
        matched_obj_counts.append(len(valid_match_ids))
        
        if len(valid_match_ids) == 0:
            continue

        match_arr = np.array(list(valid_match_ids))
        mean_ra = np.mean(ra[match_arr])
        mean_dec = np.mean(dec[match_arr])
        candidate_groups.append((mean_ra, mean_dec))
        already_matched_ids.update(valid_match_ids)

    if(len(candidate_groups) == 0):
        return [Row(obj_id=0, **this_row.asDict()) for this_row in rows]
        
    # Now we match all the sources to their closest object, if it's within bounds
    reverse_tree = cKDTree(np.stack(candidate_groups))
    dists, idx = reverse_tree.query(np.stack((ra, dec), axis=1), distance_upper_bound=3/3600.0)
    
    obj_ids = htm_id*100000 + idx
    return [Row(obj_id=obj_id.item(), **this_row.asDict()) for (obj_id, this_row) in zip(obj_ids, rows)]


grouped_cells = catalog2.withColumn("htm_id", htm_udf(catalog2.ALPHAWIN_J2000,
                                                      catalog2.DELTAWIN_J2000)).rdd.groupBy(lambda x: x['htm_id'])
new_rows = grouped_cells.flatMap(match_sources_sparksql)
new_df = spark.createDataFrame(new_rows)
new_df.write.save("s3://palomar-transient-factory/ptf_sources.parquet", format="parquet", partitionBy="htm_id")

Debugging
======

In [33]:
%%time
catalog2.count()

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 2.86 s


368244

In [41]:
%%time
cat_with_htm = catalog2.withColumn("htm_id", htm_udf(catalog2.ALPHAWIN_J2000,
                                                     catalog2.DELTAWIN_J2000)).limit(5).toPandas()

CPU times: user 16 ms, sys: 0 ns, total: 16 ms
Wall time: 4.18 s


In [None]:
lengths = catalog2.withColumn("htm_id", htm_udf(catalog2.ALPHAWIN_J2000,
                                                catalog2.DELTAWIN_J2000)).rdd.groupBy(lambda x: x['htm_id']).map(lambda x: len(x[1])).collect()


In [83]:
new_df.where(new_df.obj_id > 0).count()

366333

In [84]:
new_df.where(new_df.obj_id > 0).limit(10).toPandas()

Unnamed: 0,ALPHAWIN_J2000,AWIN_IMAGE,AWIN_WORLD,A_IMAGE,A_WORLD,BACKGROUND,BWIN_IMAGE,BWIN_WORLD,B_IMAGE,B_WORLD,...,X_WORLD,Y2WIN_IMAGE,Y2_IMAGE,YPEAK_IMAGE,YWIN_IMAGE,Y_IMAGE,Y_WORLD,ZEROPOINT,htm_id,obj_id
0,80.125276,0.689859,0.000193,1.302463,0.000365,13878.271484,0.658805,0.000185,1.195471,0.000336,...,80.125306,0.438457,1.450896,1767,1766.954966,1766.917236,1.398516,26.024088,983560,98356000000
1,80.130994,0.66673,0.000187,1.20405,0.000338,13878.270508,0.642135,0.00018,1.096167,0.000308,...,80.131027,0.418254,1.305607,1772,1772.037238,1771.920776,1.397117,26.024166,983560,98356000001
2,80.135859,0.725613,0.000204,0.798849,0.000224,13877.295898,0.551626,0.000155,0.498856,0.00014,...,80.135843,0.522567,0.637102,1774,1774.185323,1774.038208,1.396528,26.024221,983560,98356000002
3,80.137665,0.668784,0.000188,1.249107,0.00035,13875.850586,0.644145,0.000181,1.161348,0.000326,...,80.13768,0.419704,1.388179,1800,1799.759088,1799.692505,1.389324,26.024328,983560,98356000003
4,80.112051,0.684169,0.000192,0.961605,0.00027,13878.875977,0.628712,0.000177,0.830252,0.000233,...,80.112053,0.423764,0.809199,1806,1806.33169,1806.340332,1.387427,26.023792,983560,98356000004
5,80.127246,0.680116,0.000191,1.272174,0.000357,13873.689453,0.649231,0.000182,1.178229,0.000331,...,80.127271,0.429283,1.486101,1820,1819.681469,1819.611206,1.383717,26.023701,983560,98356000005
6,80.135457,0.699253,0.000196,1.018012,0.000285,13871.704102,0.650654,0.000183,0.903249,0.000254,...,80.13545,0.446193,0.823048,1823,1822.496195,1822.494263,1.382917,26.023895,983560,98356000006
7,80.102132,0.618011,0.000173,0.99078,0.000278,13885.042969,0.577064,0.000162,0.896201,0.000252,...,80.102141,0.344462,0.836559,1843,1843.245331,1843.295532,1.377036,26.02343,983560,98356000007
8,80.129837,0.769297,0.000216,1.363832,0.000382,13877.917969,0.727673,0.000204,1.26083,0.000354,...,80.12987,0.544227,1.660013,1902,1902.468121,1902.426392,1.360459,26.022034,983560,98356000008
9,80.088185,0.699395,0.000196,1.50326,0.000422,13883.77832,0.65346,0.000184,0.934541,0.000262,...,80.088252,0.438646,1.190953,1917,1917.064905,1916.914429,1.356343,26.022785,983560,98356000009
