In [4]:
import axs
import numpy as np
from astropy.io import fits
import matplotlib
import matplotlib.pyplot as plt

import os
from pyspark.sql.functions import size as spark_size
from matplotlib.backends.backend_pdf import PdfPages


# Raise the default figure resolution for every plot in this notebook.
matplotlib.rcParams.update({'figure.dpi': 120})

In [5]:
def spark_start(project_path, metastore=None, master="local[4]",
                driver_memory="6G", offheap_size="4G", max_result_size="6G"):
    """Create (or reuse) a Hive-enabled SparkSession for this project.

    Parameters
    ----------
    project_path : str
        Directory under which ``spark-warehouse`` (``spark.sql.warehouse.dir``)
        and ``spark-tmp`` (``spark.local.dir`` scratch space) are placed.
    metastore : str, optional
        Directory passed to the driver JVM as ``-Dderby.system.home``. When
        omitted, ``spark.driver.extraJavaOptions`` is not set at all. Before this
        change, ``None`` was formatted into the option, which pointed Derby at a
        directory literally named "None".
    master, driver_memory, offheap_size, max_result_size : str, optional
        Values for ``spark.master``, ``spark.driver.memory``,
        ``spark.memory.offHeap.size`` and ``spark.driver.maxResultSize``.
        The defaults reproduce the previous hard-coded settings. The original
        comments listed 128 / 256 as alternative memory sizes.

    Returns
    -------
    SparkSession
        Whatever ``builder.getOrCreate()`` returns.
    """
    from pyspark.sql import SparkSession

    warehouse_location = os.path.join(project_path, 'spark-warehouse')
    local_dir = os.path.join(project_path, 'spark-tmp')

    builder = (
        SparkSession.builder
        .appName("LSD2")
        .config("spark.sql.warehouse.dir", warehouse_location)
        .config('spark.master', master)
        .config('spark.driver.memory', driver_memory)
        .config('spark.local.dir', local_dir)
        .config('spark.memory.offHeap.enabled', 'true')
        .config('spark.memory.offHeap.size', offheap_size)
        .config("spark.sql.execution.arrow.enabled", "true")
        .config("spark.driver.maxResultSize", max_result_size)
    )
    if metastore is not None:
        # Only override Derby's home when a location was actually requested.
        builder = builder.config("spark.driver.extraJavaOptions",
                                 f"-Dderby.system.home={metastore}")

    return builder.enableHiveSupport().getOrCreate()

# Root directory for the Spark warehouse and scratch space. Override it with the
# LSD2_PROJECT_PATH environment variable instead of editing a hard-coded user path.
PROJECT_PATH = os.environ.get("LSD2_PROJECT_PATH", "/epyc/users/ctslater")

spark_session = spark_start(PROJECT_PATH)

catalog = axs.AxsCatalog(spark_session)

In [6]:
# ZTF light-curve table from the AXS catalog.
ztf_table_name = "ztf_1am_lc"
ztf = catalog.load(ztf_table_name)

In [4]:
%%time

# Look up a single ZTF object by (matchid, zone). The output shows matchid is
# stored as a zero-padded string, so the integer literal depends on Spark's
# implicit cast when comparing.
target_object = (ztf['matchid'] == 37712032113057) & (ztf['zone'] == 4823)
ztf.select("*").where(target_object).collect()

CPU times: user 14.4 ms, sys: 8.51 ms, total: 22.9 ms
Wall time: 19.3 s


[Row(matchid='00037712032113057', ra=227.4292279, dec=-9.6026943, zone=4823, dup=0, astrometricrms=0.0, bestastrometricrms=0.0, bestchisq=0.0, bestcon=0.0, bestlineartrend=0.0, bestmagrms=0.0, bestmaxmag=0.0, bestmaxslope=0.0, bestmeanmag=0.0, bestmedianabsdev=0.0, bestmedianmag=0.0, bestminmag=0.0, bestnmedianbufferrange=0, bestnpairposslope=0, bestprobnonqso=0.0, bestprobqso=0.0, bestskewness=0.0, bestsmallkurtosis=0.0, beststetsonj=0.0, beststetsonk=0.0, bestvonneumannratio=0.0, bestweightedmagrms=0.0, bestweightedmeanmag=0.0, chisq=0.0, con=0.0, lineartrend=0.0, magrms=None, maxmag=20.48711585998535, maxslope=0.0, meanmag=20.48711585998535, medianabsdev=0.0, medianmag=20.48711585998535, minmag=20.48711585998535, nbestobs=0, ngoodobs=1, nmedianbufferrange=0, nobs=1, npairposslope=0, probnonqso=0.0, probqso=0.0, refchi=0.0, refmag=0.0, refmagerr=0.0, refsharp=0.0, refsnr=0.0, skewness=0.0, smallkurtosis=0.0, stetsonj=0.0, stetsonk=0.0, uncalibmeanmag=0.0, vonneumannratio=0.0, weighte

In [7]:
# Print the docstring of the AXS crossmatch method (exploration aid; output shown below).
help(ztf.crossmatch)

Help on method crossmatch in module axs.axsframe:

crossmatch(axsframe, r=0.0002777777777777778) method of axs.axsframe.AxsFrame instance
    Performs the cross-match operation between this AxsFrame and `axsframe`, which can be either an AxsFrame or
    a Spark's DataFrame, using `r` for the cross-matching radius (one arc-second by default).
    
    Both frames need to have `zone`, `ra`, `dec`, and `dup` columns.
    
    Bote that if `axsframe` is a Spark frame, the cross-match operation will not be optimized and might take
    a much longer time to complete.
    
    The best performance can be expected when both tables are read directly as AxsFrames.
    In that scenario cross-matching will be done on bucket pairs in parallel without data movement between
    executors. If, however, one of the two AxsFrames being cross-matched is the result of a `groupBy` operation,
    for example, data movement cannot be avoided. In those cases, it might prove faster to first save the
    "groupe

In [7]:
# Sesar RR Lyrae table; it is cross-matched against ZTF below.
sesar_table_name = "sesar_rrlyrae"
sesar_axs = catalog.load(sesar_table_name)

In [8]:
%%time
# Selection thresholds for the RR Lyrae light-curve sample.
MIN_EPOCHS = 5        # keep sources with more than this many ZTF epochs
MIN_RRL_SCORE = 0.8   # cut on S3ab / S3c (presumably Sesar classification scores; confirm)
N_RESULTS = 200       # number of rows pulled back to the driver

matched = sesar_axs.crossmatch(ztf).drop("axsdist")
has_enough_epochs = spark_size(matched['mjd']) > MIN_EPOCHS
is_rrlyrae = (matched['S3ab'] > MIN_RRL_SCORE) | (matched['S3c'] > MIN_RRL_SCORE)

# Filter before projecting. S3ab/S3c are not among the selected columns, so
# filtering after select() only worked via Spark's missing-reference resolution.
# NOTE(review): head() without orderBy returns an arbitrary N_RESULTS rows. Row
# indices used later (e.g. results[28:32]) may differ between runs.
results = (matched
           .where(has_enough_epochs & is_rrlyrae)
           .select("ra", "dec", "matchid", "Per", "weightedmeanmag",
                   "filterid", "mjd", "psfflux")
           .head(N_RESULTS))

CPU times: user 54.4 ms, sys: 17.2 ms, total: 71.6 ms
Wall time: 1min 26s


In [None]:
%%time

# Counting every crossmatched source with more than 5 epochs needs a full pass
# over the crossmatch and takes a while, so it is opt-in.
RUN_FULL_COUNT = False
if RUN_FULL_COUNT:
    print(matched.where(spark_size(matched['mjd']) > 5).count())

In [10]:
# Inspect the first selected source, including its full filterid/mjd/psfflux arrays.
results[0]

Row(ra=265.75893, dec=19.82638, matchid='0005871103101088', Per=0.57479179, weightedmeanmag=15.913511276245117, filterid=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], mjd=[58287.248669, 58307.316331, 58290.2340625, 58290.4033102, 58308.2273495, 58311.3370023, 58308.2084838, 58292.215706, 58287.3373148, 58307.2334028, 58286.4210417, 58303.4062269, 58306.2952083, 58276.3783333, 58312.3361806, 58288.3723495, 58310.3515509, 58234.3367477, 58312.3602662, 58288.379294, 58308.3797685, 58204.4697338, 58205.5278935, 58206.5105324, 58207.4672222, 58237.3199306, 58292.416331, 58207.4278588

In [17]:
import cesium
from cesium.time_series import TimeSeries
from cesium.featurize import featurize_single_ts

In [12]:
# NOTE(review): redundant; the next cell assigns `result = results[0]` again.
result = results[0]

In [19]:
%%time
result = results[0]

# Generic light-curve statistics computed by cesium.
features_to_use = ["amplitude",
                   "percent_beyond_1_std",
                   "maximum",
                   "max_slope",
                   "median",
                   "median_absolute_deviation",
                   "percent_close_to_median",
                   "minimum",
                   "skew",
                   "std",
                   "weighted_average"]
# Periodic-fit features (first two fitted frequencies and their harmonics).
ls_features = ["freq1_amplitude1",
               "freq1_amplitude2",
               "freq1_amplitude3",
               "freq1_amplitude4",
               "freq1_freq",
               "freq1_lambda",
               "freq1_rel_phase2",
               "freq1_rel_phase3",
               "freq1_rel_phase4",
               "freq1_signif",
               "freq2_amplitude1",
               "freq2_amplitude2",
               "freq2_amplitude3",
               "freq2_amplitude4",
               "freq2_freq",
               "freq2_rel_phase2",
               "freq2_rel_phase3",
               "freq2_rel_phase4"]

# The light curves interleave several filters; the plotting section splits
# filterid 1 and 2. Featurizing them as one series mixes bands whose flux levels
# can differ, so use a single filter. (In ZTF convention 1 = g; confirm.)
FEATURE_FILTERID = 1
in_band = np.asarray(result['filterid']) == FEATURE_FILTERID
assert in_band.any(), f"source has no epochs in filterid {FEATURE_FILTERID}"

ts = TimeSeries(t=np.asarray(result['mjd'])[in_band],
                m=np.asarray(result['psfflux'])[in_band])
feat_out = featurize_single_ts(ts, features_to_use + ls_features)

CPU times: user 1.36 s, sys: 145 ms, total: 1.5 s
Wall time: 51.6 ms


In [21]:
# Period implied by the first fitted frequency, for comparison with the catalog `Per` in the next cell.
1/feat_out['freq1_freq']

channel
0    0.574923
dtype: float64

In [22]:
# Catalog period of this source; compare with 1/freq1_freq above.
result['Per']

0.57479179

In [20]:
# All cesium features computed for this source.
feat_out

feature                    channel
amplitude                  0            9396.832794
percent_beyond_1_std       0               0.307190
maximum                    0           19601.496094
max_slope                  0          825072.546836
median                     0           12735.896484
median_absolute_deviation  0            2143.480469
percent_close_to_median    0               0.431373
minimum                    0             807.830505
skew                       0              -0.375430
std                        0            3359.907709
weighted_average           0           13263.989591
freq1_amplitude1           0            3022.199473
freq1_amplitude2           0            1063.568925
freq1_amplitude3           0             230.753765
freq1_amplitude4           0              16.869063
freq1_freq                 0               1.739364
freq1_lambda               0               7.425334
freq1_rel_phase2           0              -0.952610
freq1_rel_phase3           0 

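Big book of lightcurves
---------

Phase-folded light curves of the selected sources, folded on the catalog period `Per`, with four panels per page, written to `ztf_lyrae.pdf`. Each panel shows the index into `results` in the upper-left corner and `weightedmeanmag` in the upper-right.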

In [54]:
%%capture
# %%capture silences this cell. In addition, each figure is closed as soon as it
# is written to the PDF, so the inline backend does not render the pages.

PANEL_ROWS, PANEL_COLS = 2, 2
PANELS_PER_PAGE = PANEL_ROWS * PANEL_COLS
PHASED_FILTERIDS = (1, 2)   # filters overplotted in each panel, in this order


def plot_phased_lightcurve(ax, source, label):
    """Draw one source's light curve folded on its catalog period ``Per``.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw into.
    source : Row
        One element of ``results``. Reads ``mjd``, ``psfflux``, ``filterid``,
        ``Per`` and ``weightedmeanmag``.
    label : int
        Number printed in the upper-left corner (the index into ``results``).
    """
    mjd = np.asarray(source['mjd'])
    flux = np.asarray(source['psfflux'])
    filterid = np.asarray(source['filterid'])
    phase = mjd / source["Per"] % 1  # evaluates as (mjd / Per) mod 1

    for fid in PHASED_FILTERIDS:
        in_band = filterid == fid
        if in_band.any():
            ax.plot(phase[in_band], flux[in_band], '.')

    # y range: mean +/- 3 std over all epochs (both filters together).
    flux_mean = np.mean(flux)
    flux_std = np.std(flux)
    ax.set_ylim(flux_mean - 3*flux_std, flux_mean + 3*flux_std)
    ax.set_xlim(0, 1)
    ax.set_xlabel("Phase")
    ax.set_ylabel("Flux")
    ax.tick_params(axis="y", labelleft=False)
    # Pin the tick positions. Labelling only the ends previously assumed the
    # auto-locator would pick exactly six ticks (0, 0.2, ..., 1).
    ax.set_xticks(np.linspace(0, 1, 6))
    ax.set_xticklabels(["0", "", "", "", "", "1"])

    ax.text(0.85, 0.95, "{:0.1f}".format(source['weightedmeanmag']),
            fontsize=8, verticalalignment="top", transform=ax.transAxes)
    ax.text(0.05, 0.95, "{:d}".format(label),
            fontsize=8, verticalalignment="top", transform=ax.transAxes)


# Ceiling division keeps a final partial page. Floor division silently dropped
# the last len(results) % 4 sources.
n_pages = -(-len(results) // PANELS_PER_PAGE)

with PdfPages('ztf_lyrae.pdf') as pdf:
    for page in range(n_pages):
        fig, axes = plt.subplots(PANEL_ROWS, PANEL_COLS)
        for m, ax in enumerate(axes.flatten()):
            result_id = PANELS_PER_PAGE * page + m
            if result_id < len(results):
                plot_phased_lightcurve(ax, results[result_id], result_id)
            else:
                ax.set_visible(False)  # unused panel on the last page
        pdf.savefig(fig)
        plt.close(fig)

Handling Duplicates
==========

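The following is an experiment. Some sources appear more than once with different ZTF `matchid`s: rows 28–31 of `results` (below) come in pairs with identical `ra`/`dec` but different `matchid`s. The cells below try to find such duplicates by cross-matching the matched table against itself.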

In [80]:
# Rows 28-31 show the duplicate problem: pairs with identical ra/dec but different matchids.
duplicate_examples = results[28:32]
[tuple(row[key] for key in ('ra', 'dec', 'matchid')) for row in duplicate_examples]

[(284.0499, 51.40418, '00079801032010593'),
 (284.0499, 51.40418, '0007980103107527'),
 (259.79441, 56.32541, '0007961001102429'),
 (259.79441, 56.32541, '0007961001205113')]

In [None]:
%%time

matched = sesar_axs.crossmatch(ztf).drop("axsdist")

# crossmatch() needs zone/ra/dec/dup on both sides, so keep those names and
# prefix every other column of the second copy with "a_" to tell the sides apart.
positional_cols = {'zone', 'ra', 'dec', 'dup'}
renames = [col if col in positional_cols else "a_" + col for col in matched.columns]

# toDF() yields a plain Spark DataFrame. Per the crossmatch docstring, this
# match is therefore not bucket-optimized and may be slow.
self_matched = matched.crossmatch(matched.toDF(*renames))

# Every row matches itself, and each genuine pair is found twice (A-B and B-A).
# Keeping matchid < a_matchid drops the self-matches and reports each pair once.
duplicate_pairs = self_matched.where(self_matched['matchid'] < self_matched['a_matchid'])
duplicate_pairs.select("ra", "dec", "matchid", "a_matchid").head(5)

