In [1]:
import axs
import numpy as np

from astropy.io import fits
import astropy.coordinates as coord
import astropy.units as u

import pandas as pd
import matplotlib.pyplot as plt

import pyspark.sql.functions as sparkfunc
import pyspark.sql.types as pyspark_types
import numpy as np

import tqdm

In [19]:
# Render matplotlib figures inline, below the cell that creates them.
%matplotlib inline

# Setup spark

In [3]:
def spark_start(local_dir, master="local[6]", driver_memory="8G",
                offheap_size="4G", max_result_size="6G"):
    """Start (or reuse) a Hive-enabled SparkSession named "LSD2".

    Parameters
    ----------
    local_dir : str
        Directory used for Spark scratch space (``spark.local.dir``), the SQL
        warehouse (``spark.sql.warehouse.dir``) and Derby's system home.
    master : str, optional
        Spark master URL. Defaults to ``"local[6]"`` (6 local worker threads).
    driver_memory : str, optional
        Value for ``spark.driver.memory``. Defaults to ``"8G"``. An older
        inline note suggested 128 (presumably GB) on a larger machine.
    offheap_size : str, optional
        Value for ``spark.memory.offHeap.size`` (off-heap memory is always
        enabled). Defaults to ``"4G"``; an older note suggested 256.
    max_result_size : str, optional
        Value for ``spark.driver.maxResultSize``, the cap on data returned to
        the driver by actions such as ``collect()``/``toPandas()``.
        Defaults to ``"6G"``.

    Returns
    -------
    pyspark.sql.SparkSession
        The session from ``getOrCreate()``. If a session already exists in
        this process, it is returned instead of a new one.
    """
    from pyspark.sql import SparkSession

    spark = (
        SparkSession.builder
        .appName("LSD2")
        .config("spark.sql.warehouse.dir", local_dir)
        .config("spark.master", master)
        .config("spark.driver.memory", driver_memory)
        .config("spark.local.dir", local_dir)
        .config("spark.memory.offHeap.enabled", "true")
        .config("spark.memory.offHeap.size", offheap_size)
        # Arrow speeds up Spark <-> pandas conversion (e.g. toPandas()).
        # NOTE(review): Spark >= 3.0 renamed this key to
        # spark.sql.execution.arrow.pyspark.enabled (old name deprecated).
        .config("spark.sql.execution.arrow.enabled", "true")
        .config("spark.driver.maxResultSize", max_result_size)
        # Point Derby's system home (metastore files) at local_dir.
        .config("spark.driver.extraJavaOptions", f"-Dderby.system.home={local_dir}")
        .enableHiveSupport()
        .getOrCreate()
    )

    return spark

import os

# Scratch/warehouse directory for Spark. Defaults to the epyc location; set the
# LSD2_SPARK_LOCAL_DIR environment variable to run elsewhere.
SPARK_LOCAL_DIR = os.environ.get("LSD2_SPARK_LOCAL_DIR", "/epyc/users/kyboone/spark-tmp/")

spark_session = spark_start(SPARK_LOCAL_DIR)

# AXS catalog bound to this Spark session; used below to load the ZTF table.
catalog = axs.AxsCatalog(spark_session)

In [4]:
# Display the session object; its rich repr includes a "Spark UI" link.
spark_session

Hovering over "Spark UI" above shows the port number of the Spark web dashboard. Epyc doesn't have that port open, though, so we use an SSH tunnel to forward it. I like to put the following function into my `.bashrc` on my local machine:


```
function spark_tunnel()
{
        # this function takes one argument: the epyc port to tunnel
        # the ordering is backwards (requiring a manual refresh) because
        # I want to be able to manually kill the ssh tunnel
        open http://localhost:${1}/
        ssh -N -L ${1}:127.0.0.1:${1} username@epyc.astro.washington.edu
}
```

What tables does AXS know about?

# Load ZTF data

In [5]:
# Load the 'ztf_mar19_all' ZTF table registered in the AXS catalog.
ztf = catalog.load('ztf_mar19_all')

# Plotting

In [6]:
def plot_lightcurve(row, ax=None):
    """Plot one object's photometry as error bars, one series per filter.

    Only observations with ``catflags == 0`` are drawn.

    Parameters
    ----------
    row : mapping (e.g. a pandas Series)
        One object's record. Must provide the per-observation sequences
        ``'mjd'``, ``'filterid'``, ``'psfmag'``, ``'psfmagerr'`` and
        ``'catflags'`` (all the same length), plus a scalar ``'matchid'``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, a new 8x6 inch figure is created.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 6))

    # Convert the per-observation sequences once, not once per filter.
    mjd = np.asarray(row['mjd'])
    filterid = np.asarray(row['filterid'])
    psfmag = np.asarray(row['psfmag'])
    psfmagerr = np.asarray(row['psfmagerr'])
    unflagged = np.asarray(row['catflags']) == 0

    for fid in np.unique(filterid):
        cut = unflagged & (filterid == fid)
        ax.errorbar(mjd[cut], psfmag[cut], psfmagerr[cut], fmt='o',
                    c='C%d' % fid, label='Filter %d' % fid)

    ax.set_xlabel('mjd')
    ax.set_ylabel('Magnitude')
    ax.legend()
    ax.set_title('matchid %d' % row['matchid'])
    # Magnitudes: brighter objects have smaller values, so flip the y axis.
    ax.invert_yaxis()
    return ax

# Cython setup

In [7]:
class cython_function():
    """Lazily loaded, picklable handle to a function in a pyximport-compiled module.

    The target function is imported only on first call. When the handle is
    pickled (e.g. shipped to Spark workers as part of a UDF), the loaded
    function is dropped. Only the module and function names travel, so each
    process compiles/imports the module itself.

    Parameters
    ----------
    module : str
        Importable module name. Dotted names such as ``'pkg.mod'`` are
        supported.
    name : str
        Name of the function inside ``module``.
    """

    def __init__(self, module, name):
        self.module = module
        self.name = name
        # Resolved lazily by load_function(); reset to None when pickled.
        self.function = None

    def load_function(self):
        """Install pyximport (with reload support) and resolve the target function."""
        import importlib
        import pyximport
        pyximport.install(reload_support=True)
        # import_module returns the named (sub)module itself, whereas the
        # builtin __import__('a.b') would return the top-level package 'a'.
        self.function = getattr(importlib.import_module(self.module), self.name)

    def __call__(self, *args, **kwargs):
        """Call the target function, loading it first if necessary."""
        if self.function is None:
            self.load_function()

        return self.function(*args, **kwargs)

    def __getstate__(self):
        # Drop the loaded function from the pickled state so that each
        # unpickling process has to compile/import the module itself.
        state = self.__dict__.copy()
        state['function'] = None
        return state

In [8]:
# Lazy handles to functions in the Cython module 'dipper'; the module is
# compiled/imported on first call in each process (see cython_function).
group_observations = cython_function('dipper', 'group_observations')
detect_dippers = cython_function('dipper', 'detect_dippers')

In [9]:
# Run this cell to recompile the cython code whenever needed.
import sys
from importlib import reload
import pyximport

pyximport.install(reload_support=True)

# Forget any previously imported build so that importing recompiles 'dipper'.
sys.modules.pop('dipper', None)
import dipper

# Force both lazy handles to re-resolve their functions on next call.
group_observations.function = detect_dippers.function = None

# Wrappers

In [10]:
def detect_dippers_row(row, verbose=False, return_mjd=False):
    """Run ``detect_dippers`` on one light-curve record.

    Unpacks the per-observation columns of ``row`` in the positional order
    ``detect_dippers`` expects and forwards the keyword options unchanged.

    Parameters
    ----------
    row : mapping (e.g. a pandas Series)
        Record providing 'mjd', 'filterid', 'psfmag', 'psfmagerr', 'xpos',
        'ypos' and 'catflags'.
    verbose, return_mjd : bool, optional
        Passed straight through to ``detect_dippers``.

    Returns
    -------
    Whatever ``detect_dippers`` returns for these inputs.
    """
    column_order = ('mjd', 'filterid', 'psfmag', 'psfmagerr',
                    'xpos', 'ypos', 'catflags')
    lightcurve = [row[column] for column in column_order]
    return detect_dippers(*lightcurve, verbose=verbose, return_mjd=return_mjd)

In [11]:
# Wrap detect_dippers as a Spark UDF producing a FloatType score column.
# detect_dippers is a cython_function handle: it is pickled to the workers
# without its compiled function, so each worker compiles 'dipper' on first use.
# NOTE(review): assumes detect_dippers returns a Python float (or float
# subclass) when called without return_mjd; Spark yields null for values that
# don't match the declared return type — confirm.
detect_dippers_udf = sparkfunc.udf(detect_dippers, returnType=pyspark_types.FloatType())

# Run the spark query

## Run and save the query

In [None]:
%%time

# Run on spark
res = (
    ztf.region(ra1=270, ra2=310, dec1=-10, dec2=40)
    #ztf.region(ra1=295, ra2=296, dec1=20, dec2=21)
    .exclude_duplicates()
    .where(sparkfunc.col("nobs_avail") > 20)
    .select(
        '*',
        detect_dippers_udf(ztf['mjd'], ztf['filterid'], ztf['psfmag'], ztf['psfmagerr'], ztf['xpos'], ztf['ypos'], ztf['catflags']).alias('score')
    )
    .where(sparkfunc.col("score") > 2.)
    #.collect()
    .write.parquet('./query_high_cadence_2.parquet')
)

#print(len(res))

#import pickle
#pickle.dump(res, open('test_query.pkl', 'wb'))

## Load the query results

In [12]:
# Read back the saved query results (presumably the output of the query cell
# above, copied to the shared saved_queries directory).
query = spark_session.read.parquet('/epyc/data/boyajian/saved_queries/query_high_cadence_2.parquet')

In [13]:
# Collect the saved query results into a pandas DataFrame on the driver.
query_df = query.toPandas()

In [14]:
# Recompute the dipper score locally for every object returned by the query.
# iterrows() has no len(), so pass total= to let tqdm show progress and an ETA.
new_scores = [
    detect_dippers_row(row)
    for _, row in tqdm.tqdm(query_df.iterrows(), total=len(query_df))
]

query_df['new_score'] = new_scores

42119it [01:01, 684.33it/s]


In [15]:
# Rank candidates by the locally recomputed score, highest first.
sort_df = query_df.sort_values('new_score', ascending=False)

In [23]:
def print_links(row):
    """Print a SIMBAD 20-arcsec cone-search URL for ``row``, then its coordinates.

    ``row`` must provide ``'ra'`` and ``'dec'``. Presumably these are decimal
    degrees, since the URL requests FK5/J2000 — TODO confirm.
    """
    ra, dec = row['ra'], row['dec']
    radec = "%.6f%+.6f" % (ra, dec)
    simbad_url = (
        "http://simbad.u-strasbg.fr/simbad/sim-coo?Coord=" + radec
        + "&CooFrame=FK5&CooEpoch=2000&CooEqui=2000&CooDefinedFrames=none"
        "&Radius=20&Radius.unit=arcsec&submit=submit+query&CoordList="
    )
    report = [
        simbad_url,
        "RA+Dec: " + radec,
        "RA:     %.6f" % ra,
        "Dec:    %.6f" % dec,
    ]
    print("\n".join(report))

def show_lightcurve(idx):
    """Show links, light curve and dipper score for the ``idx``-th row of ``sort_df``.

    ``idx`` is a positional index (``iloc``). Presumably ``sort_df`` is ordered
    by descending score, so ``idx=0`` is the top candidate.
    """
    candidate = sort_df.iloc[idx]
    print_links(candidate)
    plot_lightcurve(candidate)
    score = detect_dippers_row(candidate)
    print("Score:  %.3f" % score)

In [22]:
from ipywidgets import interact, IntSlider

# Browse up to the top 101 candidates. Clamp the slider to the last valid row
# so that a short sort_df doesn't make sort_df.iloc[idx] raise IndexError.
last_idx = max(0, min(100, len(sort_df) - 1))
interact(show_lightcurve, idx=IntSlider(min=0, max=last_idx, value=0));

<function __main__.show_lightcurve>