In [1]:
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as widgets

import axs
import pyspark.sql.functions as sparkfunc

In [2]:
#https://github.com/keatonb/dfbrowser
from dfbrowser import dfbrowser

In [3]:
from pyspark.sql import SparkSession

# Local Spark session settings, applied to the builder in this order.
_spark_conf = {
    "spark.master": "local[*]",
    "spark.driver.memory": "120g",
    # NOTE(review): presumably lifted because later cells collect whole tables
    # to the driver with toPandas() — confirm against Spark's config docs.
    "spark.driver.maxResultSize": "0",
}

_builder = SparkSession.builder
for _conf_key, _conf_value in _spark_conf.items():
    _builder = _builder.config(_conf_key, _conf_value)
spark = _builder.enableHiveSupport().getOrCreate()

import axs
# AXS catalog wrapper around the session; every table below is loaded through it.
catalog = axs.AxsCatalog(spark)

In [4]:
# Source table, plus Gaia DR2 for the CMD (join recipe from Jim's notebook).
steven, gaia = (
    catalog.load("stevengs_cut_wtf"),
    catalog.load('gaia_dr2_1am_dup'),
)

In [5]:
# Gaia columns carried through the crossmatch: position, zone/dup bookkeeping
# (dup is filtered on when merging), parallax, and G/BP/RP photometry.
gaia_columns = [
    'ra', 'dec', 'zone', 'dup',
    'parallax', 'parallax_error',
    'phot_g_mean_mag', 'bp_rp',
    'phot_g_mean_flux', 'phot_g_mean_flux_error',
    'phot_bp_mean_flux', 'phot_bp_mean_flux_error',
    'phot_rp_mean_flux', 'phot_rp_mean_flux_error',
    'phot_variable_flag',
]
base_df = steven.crossmatch(gaia.select(*gaia_columns)).toPandas()

In [6]:
# Absolute G magnitude via the distance modulus; the 1000 / parallax term implies
# parallax in mas, giving distance in pc. Non-positive parallaxes have no
# physical distance, so mask them to NaN first. Otherwise log10 emits an
# "invalid value" RuntimeWarning for negative parallaxes and yields -inf when
# the parallax is exactly 0.
base_df["abs_mag"] = (
    base_df['phot_g_mean_mag']
    - 5 * np.log10(1000 / base_df['parallax'].where(base_df['parallax'] > 0))
    + 5
)

# Boolean flag (despite the name): parallax measured at better than 5 sigma.
base_df["parallaxsnr"] = ((base_df['parallax'] / base_df['parallax_error']) > 5)

  result = getattr(ufunc, method)(*inputs, **kwargs)


In [7]:
# Fit-result tables (8/5/20 run) keyed '<model>_<band>'; built skew first, then
# top-hat, each over bands r, g, i.
df_dict = {
    f'{model_key}_{band}': f'8_5_20_stevengs_{table_stub}_fits_{band}_band'
    for model_key, table_stub in (('skew', 'skew_normal'), ('tophat', 'top_hat'))
    for band in ('r', 'g', 'i')
}

In [8]:
# Collect every fit table into pandas, keyed like df_dict ('skew_r', 'tophat_g', ...).
pandas_dfs = {
    label: catalog.load(table_name).toPandas()
    for label, table_name in df_dict.items()
}

In [9]:
# Keep only rows with dup == 0, then attach both fit tables for each band.
# NOTE(review): nothing here deduplicates ps1_objid. The displayed sig_df shows
# the same ps1_objid on consecutive rows, so check whether the crossmatch or
# these merges multiply rows.
merged_df = base_df.loc[base_df['dup'] == 0]
for band in ('r', 'g', 'i'):
    skew_fits = pandas_dfs[f'skew_{band}'].loc[lambda frame: frame['dup'] == 0]
    tophat_fits = pandas_dfs[f'tophat_{band}'].loc[lambda frame: frame['dup'] == 0]

    # Shared column names pick up _skew_<band> / _tophat_<band> suffixes
    # (e.g. fit_r_skew_r), so the three bands never collide in merged_df.
    band_fits = pd.merge(
        skew_fits, tophat_fits,
        on='ps1_objid', how='inner',
        suffixes=[f'_skew_{band}', f'_tophat_{band}'],
    )
    merged_df = pd.merge(merged_df, band_fits, how='left', on='ps1_objid')

In [10]:
def extract_key(key, subkey, df=None):
    """Return ``cell[subkey]`` for every cell of column ``key`` as a numpy array.

    Cells that are not Spark ``Row`` objects (e.g. the NaN a left merge leaves
    for unmatched rows) are passed through unchanged.

    Parameters
    ----------
    key : str
        Column holding Spark ``Row`` structs.
    subkey : str
        Field to read from each ``Row``.
    df : pandas.DataFrame, optional
        Frame to read from; defaults to the notebook-global ``merged_df``.
    """
    # Imported here so the function no longer depends on a later cell having
    # imported ``sparktypes`` at module level.
    import pyspark.sql.types as sparktypes

    if df is None:
        df = merged_df
    return np.array([cell[subkey] if isinstance(cell, sparktypes.Row) else cell for cell in df[key]])

def extract_key_item(key, subkey, item, df=None):
    """Return ``cell[subkey][item]`` for every cell of column ``key`` as a numpy array.

    Same pass-through rule as ``extract_key``: non-``Row`` cells are kept as-is.

    Parameters
    ----------
    key : str
        Column holding Spark ``Row`` structs.
    subkey : str
        Sequence-valued field of each ``Row`` (e.g. ``'popt'``).
    item : int
        Index into that sequence.
    df : pandas.DataFrame, optional
        Frame to read from; defaults to the notebook-global ``merged_df``.
    """
    import pyspark.sql.types as sparktypes

    if df is None:
        df = merged_df
    return np.array([cell[subkey][item] if isinstance(cell, sparktypes.Row) else cell for cell in df[key]])

In [11]:
import pyspark.sql.types as sparktypes

# Reduced sum-square error of each model in the window around the dip, per band.
chisq_skew_g, chisq_skew_r, chisq_skew_i = (
    extract_key(f'model_error_around_dip_{band}_skew_{band}', 'reduced_sum_square_error')
    for band in 'gri'
)
chisq_tophat_g, chisq_tophat_r, chisq_tophat_i = (
    extract_key(f'model_error_around_dip_{band}_tophat_{band}', 'reduced_sum_square_error')
    for band in 'gri'
)

# Skew-normal best-fit parameters. popt follows skew_normal's argument order
# (skew, loc, xscale, yscale, offset), so indices 0, 2, 3 are skew, xscale, yscale.
skew_g, xscale_g, yscale_g = (extract_key_item('fit_g_skew_g', 'popt', idx) for idx in (0, 2, 3))
skew_r, xscale_r, yscale_r = (extract_key_item('fit_r_skew_r', 'popt', idx) for idx in (0, 2, 3))
skew_i, xscale_i, yscale_i = (extract_key_item('fit_i_skew_i', 'popt', idx) for idx in (0, 2, 3))

In [12]:
def skew_normal(x, skew, loc, xscale, yscale, offset):
    """Scaled, offset skew-normal profile: ``yscale * skewnorm.pdf(x) + offset``.

    ``skew`` is scipy's shape parameter ``a``; ``loc`` and ``xscale`` are the
    distribution's location and scale.
    """
    from scipy.stats import skewnorm
    density = skewnorm.pdf(x, skew, loc=loc, scale=xscale)
    return yscale * density + offset

def top_hat(x, loc, width, depth, offset):
    """Box-shaped profile evaluated at ``x``.

    Returns a float array shaped like ``x``. Points within ``width / 2`` of
    ``loc`` (edges included) get ``offset + depth``; all other points get
    ``offset``. NaN entries of ``x`` fail both edge comparisons and so count as
    inside the box.
    """
    import numpy as np
    x = np.array(x)
    half_width = width / 2
    is_outside = (x < loc - half_width) | (x > loc + half_width)

    y = np.zeros(x.shape)
    y[~is_outside] = offset + depth
    y[is_outside] = offset
    return y

In [14]:
sig = extract_key(key='dip', subkey='significance')  # per-row dip significance; non-Row cells pass through

In [15]:
def _band_passes_cut(chisq_skew, chisq_tophat, skew, xscale, yscale):
    """Per-row mask for one band: a good, strongly skewed skew-normal fit that beats the top hat."""
    good_skew_fit = chisq_skew < 5
    strongly_skewed = np.abs(skew) > 2
    modest_amplitude = np.abs(yscale) < 4
    narrow = np.abs(xscale) < 50
    skew_beats_tophat = chisq_tophat - chisq_skew > 2
    return good_skew_fit & strongly_skewed & modest_amplitude & narrow & skew_beats_tophat

# A candidate survives if ANY band passes.
cut = (
    _band_passes_cut(chisq_skew_g, chisq_tophat_g, skew_g, xscale_g, yscale_g)
    | _band_passes_cut(chisq_skew_r, chisq_tophat_r, skew_r, xscale_r, yscale_r)
    | _band_passes_cut(chisq_skew_i, chisq_tophat_i, skew_i, xscale_i, yscale_i)
)

cut_df = merged_df.iloc[cut]



In [16]:
sig = extract_key('dip', 'significance')
# Order the surviving candidates by dip significance, highest first.
# An ascending argsort sorts NaN last, so reversing it with [::-1] would put
# rows with NaN significance (non-Row 'dip' cells) at the TOP. Sorting the
# negated values keeps NaNs at the end. The stable sort keeps ties in their
# original row order.
sig_df = cut_df.iloc[np.argsort(-sig[cut], kind='stable')].reset_index()

In [21]:
# Inspect the candidates that passed the cut, sorted by dip significance (descending).
sig_df

Unnamed: 0,index,ps1_objid,mean_mag_g,mean_mag_r,mean_mag_i,ra_stddev,dec_stddev,ps1_gMeanPSFMag,ps1_rMeanPSFMag,ps1_iMeanPSFMag,...,mag_i_tophat_i,magerr_i_tophat_i,dip_tophat_i,window_i_tophat_i,fit_i_tophat_i,dip_window_i_tophat_i,model_error_in_dip_i_tophat_i,around_dip_window_i_tophat_i,model_error_around_dip_i_tophat_i,model_error_i_tophat_i
0,57044,167983111466390337,17.452377,16.498033,,0.000016,0.000015,17.389601,16.557699,16.144899,...,,,,,,,,,,
1,59004,133592971966272490,18.383808,17.189658,,0.000020,0.000018,18.364901,17.327999,16.817801,...,,,,,,,,,,
2,59005,133592971966272490,18.383808,17.189658,,0.000020,0.000018,18.364901,17.327999,16.817801,...,,,,,,,,,,
3,44042,129892859318332866,18.419134,16.975537,,0.000043,0.000033,18.315901,17.081699,16.500700,...,,,,,,,,,,
4,409,131123129860725804,13.776432,13.122984,12.872258,0.000022,0.000020,13.749900,13.475000,13.135000,...,"[12.8900146484375, 12.907140731811523, 12.8647...","[0.012492096982896328, 0.012483634054660797, 0...","(58328.86328125, 58327.015625, 58331.203125, 4...","([58340.1640625, 58312.37890625, 58316.4882812...","((OK, True, 0.0004851818084716797), [1.0, 1.0,...","([], [], [])","(0.0, -0.0)","([58340.1640625, 58312.37890625, 58316.4882812...","(50.127105712890625, 16.709035873413086)","(690.1306762695312, 18.161333084106445)"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1074,33283,135823305178079985,15.352830,14.623342,14.327940,0.000020,0.000014,15.315900,14.677600,14.401100,...,"[14.344916343688965, 14.328980445861816, 14.30...","[0.014021583832800388, 0.014018900692462921, 0...","(58639.22265625, 58618.8671875, 58646.1796875,...","([58709.42578125, 58713.4296875, 58718.3867187...","((OK, True, 0.0013611316680908203), [1.0, 1.0,...","([], [], [])","(0.0, -0.0)","([58709.42578125, 58713.4296875, 58718.3867187...","(72.3078842163086, 3.615394353866577)","(242.51852416992188, 4.850370407104492)"
1075,55939,157612721562077645,17.555750,16.971858,16.793944,0.000023,0.000026,17.536400,17.019600,16.832701,...,"[16.774261474609375, 16.806678771972656, 16.79...","[0.015924524515867233, 0.01613052748143673, 0....","(58372.7421875, 58369.828125, 58376.296875, 6....","([58347.36328125, 58351.328125, 58362.28515625...","((OK, True, 0.0004892349243164062), [58373.062...","([58371.2890625, 58375.32421875], [16.80988693...","(2.190582513809204, -1.095291256904602)","([58347.36328125, 58351.328125, 58362.28515625...","(14.29361629486084, 2.8587231636047363)","(572.348388671875, 9.231425285339355)"
1076,14149,87091537548223762,15.585699,14.469614,,0.000024,0.000035,15.525100,14.533100,14.139800,...,,,,,,,,,,
1077,41212,123353306121844339,19.279819,18.207027,17.756914,0.000037,0.000040,19.422199,18.647301,18.058800,...,"[17.66988754272461, 17.679216384887695, 17.795...","[0.02514365315437317, 0.025281058624386787, 0....","(58326.90625, 58319.78515625, 58331.140625, 11...","([58351.3515625, 58375.31640625, 58362.3125, 5...","((OK, True, 0.0005342960357666016), [58325.460...","([58320.23046875, 58329.37890625], [17.7751865...","(6.750715255737305, -3.3753576278686523)","([58351.3515625, 58375.31640625, 58362.3125, 5...","(44.193687438964844, 4.419368743896484)","(403.549560546875, 9.384873390197754)"


In [27]:
# Observation marker colour for each photometric band.
colormap = dict(g='tab:green', r='tab:red', i='tab:purple')

# One shared figure that plot_row clears and redraws: a panel per band,
# stacked vertically on a common MJD axis.
fig, subplots = plt.subplots(
    nrows=3,
    ncols=1,
    sharex=True,
    figsize=(8, 9),
    gridspec_kw=dict(hspace=0.05),
)

def _is_missing(value):
    """Return True if ``value`` is a missing-cell placeholder (None or a float NaN)."""
    # pandas fills unmatched cells of a left merge with NaN. An identity test
    # (``value is not np.nan``) only works if that exact object was stored,
    # so test the value itself instead.
    return value is None or (isinstance(value, (float, np.floating)) and np.isnan(value))


def _plot_model_fit(subplot, x, row, channel, tag, model_func, label, color):
    """Overlay one fitted model (columns ``*_{channel}_{tag}_{channel}``) on ``subplot`` if the row has it."""
    fit = row[f'fit_{channel}_{tag}_{channel}']
    if _is_missing(fit):
        return
    model = model_func(x, *fit['popt'])
    model_error = row[f'model_error_around_dip_{channel}_{tag}_{channel}']
    if _is_missing(model_error):
        legend = f'{label} fit {channel}'
    else:
        # NOTE(review): element 1 is presumably the reduced error (the field the
        # cuts read as 'reduced_sum_square_error') — confirm against the Row schema.
        legend = f'{label} fit {channel} - {model_error[1]:.2f}'
    subplot.plot(x, model, label=legend, c=color)


def plot_row(row):
    """Redraw the shared three-panel figure for a single candidate.

    The panels show bands r, g and i, top to bottom. Each panel shows:

    - the band's observations within one dip-window width either side of the
      dip window;
    - the skew-normal and top-hat fits, when the row has them;
    - dashed lines at the window edges.

    Parameters
    ----------
    row
        One row of the merged table, presumably a pandas Series supplied by
        dfbrowser (TODO confirm). It must provide 'dip', 'ps1_objid', and the
        per-band 'mjd_*', 'mag_*', 'magerr_*' and fit / model-error columns.
    """
    for ax in subplots:
        ax.clear()

    dip = row['dip']
    start_mjd = dip['window_start_mjd']
    end_mjd = dip['window_end_mjd']
    # Show one window-width of context on each side of the dip window.
    pad = end_mjd - start_mjd
    view_lo, view_hi = start_mjd - pad, end_mjd + pad

    x = np.linspace(view_lo, view_hi, 1000)

    for subplot, channel in zip(subplots, ['r', 'g', 'i']):
        mjd = np.array(row[f'mjd_{channel}'])
        mag = np.array(row[f'mag_{channel}'])
        magerr = np.array(row[f'magerr_{channel}'])

        in_view = (mjd > view_lo) & (mjd < view_hi)
        subplot.errorbar(mjd[in_view], mag[in_view], magerr[in_view], fmt='o',
                         c=colormap[channel], label=f'{channel} observations')

        _plot_model_fit(subplot, x, row, channel, 'skew', skew_normal, 'Skew normal', 'C0')
        _plot_model_fit(subplot, x, row, channel, 'tophat', top_hat, 'Top hat', 'C1')

        subplot.invert_yaxis()  # smaller magnitudes plot higher
        subplot.legend()
        subplot.set_xlim(view_lo, view_hi)

        subplot.axvline(start_mjd, c='k', ls='--')
        subplot.axvline(end_mjd, c='k', ls='--')

        subplot.set_ylabel('Magnitude')

    subplots[2].set_xlabel('MJD')
    subplots[0].set_title(row['ps1_objid'])
    fig.canvas.draw()
# Display the figure; it stays empty until plot_row draws into it.
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [22]:
# Interactive browser over the significance-sorted candidates. plot_row is
# passed as the per-row callback (presumably invoked on selection — see the
# dfbrowser README).
browse = dfbrowser(
    sig_df,
    alpha=0.4,
    funct=plot_row,
)

VBox(children=(HBox(children=(Dropdown(description='x var:', options=('index', 'ps1_objid', 'mean_mag_g', 'mea…