In [1]:
import axs
import numpy as np

from astropy.io import fits
import astropy.coordinates as coord
import astropy.units as u

import pandas as pd
import matplotlib.pyplot as plt

import pyspark.sql.functions as sparkfunc
import pyspark.sql.types as pyspark_types
import numpy as np

In [2]:
%matplotlib notebook

# Setup spark

In [4]:
def spark_start(local_dir):
    """Create (or reuse) the Hive-enabled Spark session for this notebook.

    Parameters
    ----------
    local_dir : str
        Directory used for the SQL warehouse, Spark's local scratch space
        and the Derby system home.

    Returns
    -------
    pyspark.sql.SparkSession
        The session returned by ``getOrCreate()``.
    """
    from pyspark.sql import SparkSession

    # (key, value) pairs applied to the builder in this order.
    settings = [
        ("spark.sql.warehouse.dir", local_dir),
        ('spark.master', "local[6]"),
        # The original author left "128" next to this value -- presumably the
        # setting for a larger machine; TODO confirm.
        ('spark.driver.memory', '8G'),
        ('spark.local.dir', local_dir),
        ('spark.memory.offHeap.enabled', 'true'),
        # Likewise annotated "256" by the original author.
        ('spark.memory.offHeap.size', '4G'),
        ("spark.sql.execution.arrow.enabled", "true"),
        ("spark.driver.maxResultSize", "6G"),
        ("spark.driver.extraJavaOptions", f"-Dderby.system.home={local_dir}"),
    ]

    builder = SparkSession.builder.appName("LSD2")
    for key, value in settings:
        builder = builder.config(key, value)

    spark = builder.enableHiveSupport().getOrCreate()
    return spark

# NOTE: the scratch/warehouse directory is a hard-coded, user-specific absolute
# path; change it when running this notebook as another user or on another host.
spark_session = spark_start("/epyc/users/kyboone/spark-tmp/")

# AXS catalog handle bound to the Spark session; tables are loaded through it below.
catalog = axs.AxsCatalog(spark_session)


In [5]:
spark_session

Hovering over "Spark UI" above gives you the port number of the Spark web dashboard. Epyc doesn't have that port open, though, so we use an SSH tunnel to forward the port. I like to put the following function into the `.bashrc` on my local machine:


```
function spark_tunnel()
{
        # this function takes one argument: the epyc port to tunnel
        if [ -z "${1}" ]; then
                # without a port both the URL and the -L forwarding spec are malformed
                echo "usage: spark_tunnel <port>" >&2
                return 1
        fi
        # the ordering is backwards (requiring a manual refresh) because
        # I want to be able to manually kill the ssh tunnel
        open "http://localhost:${1}/"
        ssh -N -L "${1}:127.0.0.1:${1}" username@epyc.astro.washington.edu
}
```

What tables does AXS know about?

# Load ZTF data

In [6]:
# Load the 'ztf_mar19_all' table from the AXS catalog; every query below starts from
# this handle. (Presumably the March 2019 ZTF light-curve table -- confirm.)
ztf = catalog.load('ztf_mar19_all')

# Plotting

In [8]:
def plot_lightcurve(row):
    """Plot one light curve, one errorbar series per filter.

    Only observations whose ``catflags`` value equals 0 are drawn. The y axis
    is inverted after plotting.

    Parameters
    ----------
    row : mapping
        Must provide 'filterid', 'catflags', 'mjd', 'psfmag' and 'psfmagerr'
        as equal-length sequences, plus a scalar 'matchid' used in the title.
    """
    plt.figure(figsize=(8, 6))

    # Convert each column once instead of on every loop iteration.
    filter_ids = np.array(row['filterid'])
    unflagged = np.array(row['catflags']) == 0.
    times = np.array(row['mjd'])
    mags = np.array(row['psfmag'])
    mag_errs = np.array(row['psfmagerr'])

    for band in np.unique(filter_ids):
        selected = (filter_ids == band) & unflagged
        plt.errorbar(
            times[selected],
            mags[selected],
            mag_errs[selected],
            fmt='o',
            c='C%d' % band,
            label='Filter %d' % band,
        )

    plt.xlabel('mjd')
    plt.ylabel('Magnitude')
    plt.legend()
    plt.title('matchid %d' % row['matchid'])
    plt.gca().invert_yaxis()

# Dipper detection

In [11]:
from scipy.ndimage import minimum_filter1d

def detect_dippers(mjd, filterid, psfmag, psfmagerr, xpos, ypos, catflags, verbose=False, return_mjd=False):
    """Score a light curve for a run of consecutive observations above the median magnitude.

    Within each filter, every usable observation gets a score
    ``(psfmag - median) / sqrt(std**2 + psfmagerr**2)``, so positive scores
    are magnitudes above the median (fainter, in the usual magnitude
    convention). Observations that are flagged, close to the chip edge, or in
    a filter with fewer than 10 usable points keep a score of 0. The result
    is the largest value of a 4-sample sliding minimum over the time-ordered
    scores, so a high result needs several high-scoring points in a row.

    Parameters
    ----------
    mjd, filterid, psfmag, psfmagerr, xpos, ypos, catflags : sequence
        Equal-length per-observation columns of one light curve.
    verbose : bool
        If True, print the mjd at which the filtered score peaks.
    return_mjd : bool
        If True, return ``(score, mjd_of_peak)`` instead of just ``score``.

    Returns
    -------
    float or (float, float)
        The score, plus the peak mjd when ``return_mjd`` is set. An empty
        light curve gives ``0.`` (or ``(0., nan)`` with ``return_mjd``).
    """
    if len(mjd) == 0:
        # Nothing to score. Honour return_mjd so callers that unpack two values
        # do not fail on an empty light curve.
        if return_mjd:
            return 0., float('nan')
        return 0.

    order = np.argsort(mjd)

    # Throw out repeated measurements: keep a sample only if it is more than
    # 1e-5 days after its predecessor. The first sample has no predecessor and
    # is always kept (comparing it against the last sample via np.roll emptied
    # light curves with one observation or with all-identical timestamps).
    ordered_mjd = np.array(mjd)[order]
    mask = np.ones(len(ordered_mjd), dtype=bool)
    mask[1:] = np.diff(ordered_mjd) > 1e-5

    mjd = ordered_mjd[mask]
    filterid = np.array(filterid)[order][mask]
    psfmag = np.array(psfmag)[order][mask]
    psfmagerr = np.array(psfmagerr)[order][mask]
    xpos = np.array(xpos)[order][mask]
    ypos = np.array(ypos)[order][mask]
    catflags = np.array(catflags)[order][mask]

    # Observations that are never scored below keep a neutral score of 0.
    scores = np.zeros(len(psfmag))

    # Reject detections within pad_width of the borders.
    # NOTE(review): 3072 x 3080 is presumably the detector size in pixels -- confirm.
    pad_width = 20
    x_border = 3072
    y_border = 3080

    for iter_filterid in np.unique(filterid):
        cut = (
            (filterid == iter_filterid)
            & (xpos > pad_width)
            & (xpos < x_border - pad_width)
            & (ypos > pad_width)
            & (ypos < y_border - pad_width)
            & (catflags == 0)
        )

        if np.sum(cut) < 10:
            # Require at least 10 observations to have reasonable statistics.
            continue

        use_psfmag = psfmag[cut]
        use_psfmagerr = psfmagerr[cut]

        # Deviation from the filter's median, normalised by the intrinsic
        # scatter and the per-point measurement error added in quadrature.
        core_std = np.std(use_psfmag)
        filter_scores = (use_psfmag - np.median(use_psfmag)) / np.sqrt(core_std**2 + use_psfmagerr**2)

        scores[cut] = filter_scores

    # Check for sequential runs: take the minimum score in each 4-sample
    # window. Windows that overhang the ends are padded with 0 (mode='constant').
    filtered_scores = minimum_filter1d(scores, 4, mode='constant')

    result = float(np.max(filtered_scores))
    max_mjd = mjd[np.argmax(filtered_scores)]

    if verbose:
        print("Max mjd: ", max_mjd)

    if return_mjd:
        return result, max_mjd
    else:
        return result

def detect_dippers_row(row, verbose=False, return_mjd=False):
    """Run ``detect_dippers`` on a single result row.

    ``row`` must be indexable by the seven light-curve column names listed
    below; ``verbose`` and ``return_mjd`` are forwarded unchanged.
    """
    # Column order matches the positional parameters of detect_dippers.
    columns = ('mjd', 'filterid', 'psfmag', 'psfmagerr', 'xpos', 'ypos', 'catflags')
    light_curve = [row[name] for name in columns]
    return detect_dippers(*light_curve, verbose=verbose, return_mjd=return_mjd)

# Create a UDF for spark. The query passes the seven light-curve columns
# positionally, so `verbose` and `return_mjd` keep their defaults and each row
# yields a single float score (hence the FloatType return type).
detect_dippers_udf = sparkfunc.udf(detect_dippers, returnType=pyspark_types.FloatType())

# Run the spark query

## Run and save the query

In [15]:
%%time

# Run on spark
import pickle

# Score every light curve in the RA 296-302 / Dec 8-15 box that has more than
# 20 available observations, keep those scoring above 2, and collect the
# surviving rows to the driver.
res = (
    ztf.region(ra1=296, ra2=302, dec1=8, dec2=15)
    .exclude_duplicates()
    .where(sparkfunc.col("nobs_avail") > 20)
    .select(
        '*',
        detect_dippers_udf(ztf['mjd'], ztf['filterid'], ztf['psfmag'], ztf['psfmagerr'], ztf['xpos'], ztf['ypos'], ztf['catflags']).alias('score')
    )
    .where(sparkfunc.col("score") > 2.)
    .collect()
)
print(len(res))

# Cache the collected rows so the analysis can be redone without repeating the
# query. The context manager guarantees the file is flushed and closed.
with open('test_query.pkl', 'wb') as query_file:
    pickle.dump(res, query_file)

4189
CPU times: user 1.79 s, sys: 614 ms, total: 2.4 s
Wall time: 40min 9s


## Load the query results

In [19]:
# Imported here so this cell also works on a fresh kernel, without re-running
# the long query cell that otherwise provides `pickle`.
import pickle

# Only load pickle files you produced yourself: unpickling can execute arbitrary code.
with open('test_query.pkl', 'rb') as query_file:
    res = pickle.load(query_file)

# Analysis

In [25]:
# Order the light curves by their scores
# The rows are re-scored locally with return_mjd=True because the Spark UDF only
# returns the float score, not the mjd at which it peaks.
scores = []
mjds = []

# `scores` and `mjds` stay in the same order as `res`.
for i in res:
    score, mjd = detect_dippers_row(i, return_mjd=True)
    scores.append(score)
    mjds.append(mjd)

# Indices into `res`, highest score first.
order = np.argsort(scores)[::-1]

# Apparently things break if you cast a list of spark objects into a numpy array,
# so keep everything as a list.
ordered_res = [res[i] for i in order]

In [26]:
# Summarise and plot the five highest-scoring candidates. Slicing instead of
# indexing with range(5) avoids an IndexError when fewer than five rows came back.
for idx, row in enumerate(ordered_res[:5]):
    print("=================")
    print("idx: %d" % idx)
    # verbose=True makes detect_dippers print the peak mjd before the score is printed.
    print(detect_dippers_row(row, verbose=True))
    print("ra: %.6f" % row['ra'])
    print("dec: %.6f" % row['dec'])
    # Mean detector position of the observations in this light curve.
    print("xpos: %.2f, ypos: %.2f" % (np.mean(row['xpos']), np.mean(row['ypos'])))

    plot_lightcurve(row)

idx: 0
Max mjd:  58473.0838542
6.859878602311454
ra: 299.379173
dec: 12.904513
xpos: 1651.64, ypos: 2444.82


<IPython.core.display.Javascript object>

idx: 1
Max mjd:  58468.0846875
6.473713207113447
ra: 296.428319
dec: 10.070698
xpos: 1836.72, ypos: 2608.24


<IPython.core.display.Javascript object>

idx: 2
Max mjd:  58471.0843056
6.348094801539755
ra: 299.740774
dec: 13.178672
xpos: 402.78, ypos: 1466.24


<IPython.core.display.Javascript object>

idx: 3
Max mjd:  58471.0843056
6.166498936036734
ra: 297.327360
dec: 11.605076
xpos: 2199.09, ypos: 251.59


<IPython.core.display.Javascript object>

idx: 4
Max mjd:  58469.0841088
6.165378361698936
ra: 296.206076
dec: 10.010648
xpos: 2616.73, ypos: 2814.42


<IPython.core.display.Javascript object>

# Artifact investigation

In [22]:
# Hmm, many of our dippers show up in a line on the sky!
# Per-candidate sky position, in the same order as `res`.
ra = [i['ra'] for i in res]
dec = [i['dec'] for i in res]
# Mean x detector position per candidate; not used by the plot below.
xpos = [np.mean(i['xpos']) for i in res]

# `scores` comes from the score-ordering cell and is aligned with `res` (not with
# `ordered_res`), so it lines up with `ra`/`dec`. That cell must be run first.
plt.figure(figsize=(8, 6))
# The colour scale is clipped to scores between 2 and 4.
plt.scatter(ra, dec, c = scores, vmin=2, vmax=4)
plt.xlabel('RA')
plt.ylabel('Dec')
plt.colorbar(label='Score')

<IPython.core.display.Javascript object>

<matplotlib.colorbar.Colorbar at 0x7fa4f98ac710>