In [None]:
import axs
import numpy as np
from astropy.io import fits
import matplotlib
import matplotlib.pyplot as plt

import os
from pyspark.sql.functions import size as spark_size
from matplotlib.backends.backend_pdf import PdfPages
from pyspark.sql.functions import min as spark_min
from pyspark.sql.functions import max as spark_max

matplotlib.rcParams['figure.dpi'] = 120

In [None]:
def spark_start(project_path, metastore=None, master="local[4]",
                driver_memory="6G", offheap_size="4G", max_result_size="6G"):
    """Create (or reuse, via ``getOrCreate``) the Hive-enabled SparkSession.

    Parameters
    ----------
    project_path : str
        Directory under which the Spark warehouse (``spark-warehouse``) and
        Spark scratch space (``spark-tmp``) are placed.
    metastore : str, optional
        Passed to the driver JVM as ``-Dderby.system.home=<metastore>``.
        NOTE(review): with the default ``None`` this becomes the literal
        ``-Dderby.system.home=None`` (a directory named "None"). It is kept
        unchanged because an existing Derby/Hive metastore, and therefore the
        tables the AXS catalog loads, may already live there. Pass an explicit
        path for new setups.
    master : str
        Value for ``spark.master``.
    driver_memory, offheap_size, max_result_size : str
        Values for ``spark.driver.memory``, ``spark.memory.offHeap.size`` and
        ``spark.driver.maxResultSize``. The defaults are the previously
        hard-coded values. The original inline notes ("128", "256") suggest
        larger sizes on a bigger machine.

    Returns
    -------
    SparkSession
    """
    from pyspark.sql import SparkSession

    warehouse_location = os.path.join(project_path, 'spark-warehouse')
    local_dir = os.path.join(project_path, 'spark-tmp')

    spark = (
        SparkSession.builder
        .appName("LSD2")
        .config("spark.sql.warehouse.dir", warehouse_location)
        .config('spark.master', master)
        .config('spark.driver.memory', driver_memory)
        .config('spark.local.dir', local_dir)
        .config('spark.memory.offHeap.enabled', 'true')
        .config('spark.memory.offHeap.size', offheap_size)
        .config("spark.sql.execution.arrow.enabled", "true")
        .config("spark.driver.maxResultSize", max_result_size)
        .config("spark.driver.extraJavaOptions", f"-Dderby.system.home={metastore}")
        .enableHiveSupport()
        .getOrCreate()
    )

    return spark

# Root directory for the Spark warehouse and scratch space used by spark_start.
# Override it with the LSD2_PROJECT_PATH environment variable rather than
# editing a hard-coded user path. The default is the original location.
PROJECT_PATH = os.environ.get("LSD2_PROJECT_PATH", "/epyc/users/ctslater")

spark_session = spark_start(PROJECT_PATH)

catalog = axs.AxsCatalog(spark_session)

In [None]:
# Load the two AXS tables to crossmatch: `ztf_1am_lc` (bound to `ztf`) and
# `sesar_rrlyrae` (bound to `sesar_axs`).
ztf, sesar_axs = [catalog.load(table) for table in ("ztf_1am_lc", "sesar_rrlyrae")]

In [None]:
%%time
# Crossmatch the sesar_rrlyrae table against ZTF and keep matches that have
# more than EPOCH_THRESHOLD mjd entries and S3ab or S3c above SCORE_THRESHOLD.
# NOTE(review): if crossmatch/where/select are lazy like ordinary Spark
# transformations, %%time here measures only query-plan construction. The real
# work happens when rows are collected below.
EPOCH_THRESHOLD = 5     # keep sources with strictly more mjd entries than this
SCORE_THRESHOLD = 0.8   # S3ab or S3c must exceed this

matched = sesar_axs.crossmatch(ztf)
matched_filtered = (
    matched
    # Filter before projecting. S3ab/S3c are not among the selected columns, so
    # filtering after .select() would rely on Spark re-resolving dropped columns.
    .where((spark_size(matched['mjd']) > EPOCH_THRESHOLD) &
           ((matched['S3ab'] > SCORE_THRESHOLD) | (matched['S3c'] > SCORE_THRESHOLD)))
    .select("ra", "dec", "matchid", "Per", "weightedmeanmag", "filterid", "mjd", "psfflux")
)

In [None]:
# Pull the first 200 filtered matches back to the driver; the plotting cell
# indexes into this list.
results = matched_filtered.take(200)

In [None]:
%%capture
# Draw phase-folded light curves, four panels per page, into a multi-page PDF.
# Each page is saved and closed right after it is drawn. pyplot therefore never
# holds dozens of open figures, and the pages are not displayed inline at the
# end of the cell. %%capture above only silences leftover warnings.

N_ROWS, N_COLS = 2, 2                      # panel grid on each PDF page
PANELS_PER_PAGE = N_ROWS * N_COLS
PHASE_TICKS = np.linspace(0, 1, 6)         # ticks at 0, 0.2, ..., 1
PHASE_TICK_LABELS = ["0", "", "", "", "", "1"]
# A fixed colour per filterid keeps each filter the same colour in every panel.
# With the default colour cycle, filterid 2 took filterid 1's colour whenever a
# source had no filterid-1 points.
FILTER_COLORS = {1: "C0", 2: "C1"}
PDF_PATH = 'ztf_lyrae.pdf'


def _plot_phased_lightcurve(ax, source, result_id):
    """Plot one source's phase-folded light curve onto ``ax``.

    Phase is ``mjd / Per`` modulo 1. Only the filterids in ``FILTER_COLORS``
    are drawn.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    source : row of ``results``
        Must provide ``filterid``, ``mjd``, ``psfflux``, ``Per`` and
        ``weightedmeanmag``.
    result_id : int
        Index of ``source`` in ``results``, printed in the panel's corner.
    """
    filterid = np.asarray(source['filterid'])
    # dtype=float maps None entries to NaN so the nan-aware statistics below work.
    # NOTE(review): assumes psfflux is a flat sequence of numbers/None.
    flux = np.asarray(source['psfflux'], dtype=float)
    phase = np.asarray(source['mjd']) / source["Per"] % 1

    for fid, color in FILTER_COLORS.items():
        sel = filterid == fid
        if sel.any():
            ax.plot(phase[sel], flux[sel], '.', color=color)

    # y-range: mean +/- 3 sigma over all psfflux values (every filterid).
    flux_mean = np.nanmean(flux)
    flux_std = np.nanstd(flux)
    # matplotlib rejects NaN limits, and zero spread gives identical limits.
    # In those cases leave the axis autoscaled.
    if np.isfinite(flux_mean) and np.isfinite(flux_std) and flux_std > 0:
        ax.set_ylim(flux_mean - 3*flux_std, flux_mean + 3*flux_std)
    ax.set_xlim(0, 1)
    ax.set_xlabel("Phase")
    ax.set_ylabel("Flux")
    ax.tick_params(axis="y", labelleft=False)
    # Fix the tick positions before labelling them, so the labels cannot drift
    # onto different ticks.
    ax.set_xticks(PHASE_TICKS)
    ax.set_xticklabels(PHASE_TICK_LABELS)

    ax.text(0.85, 0.95, "{:0.1f}".format(source['weightedmeanmag']),
            fontsize=8, verticalalignment="top", transform=ax.transAxes)
    ax.text(0.05, 0.95, "{:d}".format(result_id),
            fontsize=8, verticalalignment="top", transform=ax.transAxes)


# Ceiling division keeps a partial last page. Plain // dropped the trailing
# 1-3 sources whenever len(results) was not a multiple of PANELS_PER_PAGE.
n_pages = (len(results) + PANELS_PER_PAGE - 1) // PANELS_PER_PAGE

figures = []
with PdfPages(PDF_PATH) as pdf:
    for page in range(n_pages):
        fig, axes = plt.subplots(N_ROWS, N_COLS)
        for slot, ax in enumerate(axes.flatten()):
            result_id = PANELS_PER_PAGE * page + slot
            if result_id < len(results):
                _plot_phased_lightcurve(ax, results[result_id], result_id)
            else:
                ax.set_visible(False)   # unused panel on the last page
        pdf.savefig(fig)
        # Closing only detaches the figure from pyplot. The Figure object kept
        # in `figures` can still be saved (fig.savefig) by later cells.
        plt.close(fig)
        figures.append(fig)