In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import glob
from astropy import log
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt

import copy
from astropy.visualization import quantity_support
from astropy.table import Table, QTable
import matplotlib
import astropy.units as u
from astroduet.config import Telescope
from astroduet.background import background_pixel_rate
font = {'size'   : 22}
from astroduet.models import Simulations, fits_file, load_model_ABmag, load_model_fluence
matplotlib.rcParams.update({'font.size': 22})
from astroduet.lightcurve import get_lightcurve, lightcurve_through_image
import astroduet.image_utils
from astroduet.utils import duet_fluence_to_abmag
import seaborn as sns
from scipy.interpolate import interp1d



In [None]:
# Do it only once
# sims = Simulations()
# sims.parse_emgw()
# sims.parse_sne_bsg()
# sims.parse_sne_ysg()
# sims.parse_sne_rsg()


In [None]:
# Time grids of the three supernova models at 10 Mpc; ab1/ab2 end up holding
# the bsg80 magnitudes, as they are assigned last.
ytime = load_model_ABmag('ysg400', dist=10 * u.Mpc)[0]
rtime = load_model_ABmag('rsg400', dist=10 * u.Mpc)[0]
btime, ab1, ab2 = load_model_ABmag('bsg80', dist=10 * u.Mpc)


In [None]:
# Last time sample of each model: (bsg80, ysg400, rsg400)
(btime[-1], ytime[-1], rtime[-1])

In [None]:
def _seconds_to_days(times):
    """Convert times (a Quantity, or bare numbers taken to be seconds) to days."""
    return u.Quantity(times, u.s).to_value(u.d)


def _load_detected_lc_and_model(input_lc_file, distance, **kwargs):
    """Simulate a light curve and load the matching model AB magnitudes.

    Returns
    -------
    lightcurve : Table
        Output of ``get_lightcurve`` restricted to the points with S/N > 1 in
        at least one of the two bands.
    model_lc_table_ab : QTable
        Model light curve with columns ``time``, ``mag_D1``, ``mag_D2``.

    Raises
    ------
    ValueError
        If no simulated point has S/N > 1 in either band (nothing to plot).
    """
    abtime, ab1, ab2 = load_model_ABmag(input_lc_file, dist=distance)
    model_lc_table_ab = QTable({'time': abtime, 'mag_D1': ab1, 'mag_D2': ab2})
    lightcurve = get_lightcurve(input_lc_file, distance=distance, **kwargs)
    good = (lightcurve['snr_D1'] > 1) | (lightcurve['snr_D2'] > 1)
    lightcurve = lightcurve[good]
    if len(lightcurve) == 0:
        raise ValueError(f"No simulated point of {input_lc_file} at {distance} "
                         "has S/N > 1 in either band: nothing to plot.")
    return lightcurve, model_lc_table_ab


def _plot_mag_panel(ax, lightcurve, model_lc_table_ab):
    """Plot simulated AB magnitudes (points) and model magnitudes (lines) on ``ax``.

    Sets the y label, an inverted magnitude axis with 1 mag of margin, the x
    range of the simulated points and the legend.
    """
    t_days = _seconds_to_days(lightcurve['time'])
    for band in ('D1', 'D2'):
        ax.errorbar(t_days, lightcurve[f'mag_{band}'].value,
                    yerr=lightcurve[f'mag_{band}_err'].value,
                    fmt='o', markersize=2, label=band)
    model_days = _seconds_to_days(model_lc_table_ab['time'])
    for band in ('D1', 'D2'):
        ax.plot(model_days, model_lc_table_ab[f'mag_{band}'])

    ax.set_ylabel("AB mag")
    ax.set_xlim([t_days[0], t_days[-1]])
    mags = np.concatenate([lightcurve['mag_D1'].value, lightcurve['mag_D2'].value])
    # nan-aware, in case one band's magnitude is undefined (NaN) for some points
    ymin = np.nanmin(mags) - 1
    ymax = np.nanmax(mags) + 1
    # Inverted axis for magnitudes
    ax.set_ylim([ymax, ymin])
    ax.legend()


def create_and_plot_lc(input_lc_file, distance=100e6*u.pc, **kwargs):
    """Plot a simulated D1/D2 AB-magnitude light curve on top of its model.

    Parameters
    ----------
    input_lc_file : str
        Model name, passed to ``load_model_ABmag`` and ``get_lightcurve``.
    distance : Quantity
        Distance of the source.
    **kwargs
        Extra arguments for ``get_lightcurve``.

    Raises
    ------
    ValueError
        If no simulated point has S/N > 1 in either band.
    """
    lightcurve, model_lc_table_ab = _load_detected_lc_and_model(input_lc_file, distance, **kwargs)
    plt.figure(figsize=(15, 8))
    gs = plt.GridSpec(1, 1, hspace=0)
    ax0 = plt.subplot(gs[0])
    _plot_mag_panel(ax0, lightcurve, model_lc_table_ab)
    ax0.set_xlabel("Time (d)")


def create_and_plot_lc_snr(input_lc_file, distance=100e6*u.pc, **kwargs):
    """Like ``create_and_plot_lc``, with an extra bottom panel showing the S/N per point.

    Raises
    ------
    ValueError
        If no simulated point has S/N > 1 in either band.
    """
    lightcurve, model_lc_table_ab = _load_detected_lc_and_model(input_lc_file, distance, **kwargs)
    plt.figure(figsize=(15, 8))
    gs = plt.GridSpec(2, 1, hspace=0)
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1], sharex=ax0)
    _plot_mag_panel(ax0, lightcurve, model_lc_table_ab)

    t_days = _seconds_to_days(lightcurve['time'])
    ax1.scatter(t_days, lightcurve['snr_D1'].value, s=2)
    ax1.scatter(t_days, lightcurve['snr_D2'].value, s=2)
    ax1.set_ylabel("S/N")
    ax1.set_xlabel("Time (d)")
    # Panels touch (hspace=0): hide the top panel's x tick labels
    ax0.tick_params(labelbottom=False)
    

In [None]:
# rsg600 model at 500 Mpc
create_and_plot_lc(input_lc_file='rsg600', distance=500 * u.Mpc)

In [None]:
# rsg400 model at 300 Mpc
create_and_plot_lc(input_lc_file='rsg400', distance=300 * u.Mpc)

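Next, we simulate the light curves through the imaging pipeline, explicitly specifying the observing window (and, optionally, a host galaxy).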

In [None]:
def get_random_galaxy(distance, duet, tolerance=10 * u.Mpc, average_gal=False, rng=None):
    """Draw random host-galaxy profile parameters for a transient at ``distance``.

    A galaxy is drawn from the BAI catalogue among those whose distance lies
    within ``tolerance`` of ``distance``. If there are none, it is drawn from
    the galaxies beyond 150 Mpc. Its catalogue radius is rescaled to
    ``distance``, and a random orientation, ellipticity and transient offset
    (0.1-3 radii from the centre) are drawn.

    Parameters
    ----------
    distance : Quantity
        Distance of the simulated transient.
    duet : Telescope
        Telescope configuration. ``duet.pixel / 6`` (6x oversampling) is the
        length unit of the returned ``r_eff``, ``x_0`` and ``y_0``.
    tolerance : Quantity
        Half-width of the distance window used to select catalogue galaxies.
    average_gal : bool
        If True, use the catalogue-median BAI1/BAI2 values as magnitudes and
        place the transient at exactly one radius from the centre.
    rng : numpy.random.Generator or None
        Source of randomness, for reproducible draws. ``None`` (default) uses
        the global ``np.random`` state, as before.

    Returns
    -------
    gal_params1, gal_params2 : dict
        Parameters for D1 and D2. They share the geometry (``r_eff``, ``n``,
        ``theta``, ``ellip``, ``x_0``, ``y_0``) and differ only in
        ``magnitude``. ``n`` is 1 for catalogue ``MORPH >= 0``, otherwise 4
        (looks like a Sersic index).
    """
    from astroduet.models import load_bai
    from astropy.table import QTable
    from astroduet.utils import galex_to_duet
    if rng is None:
        # Module-level functions draw from the global numpy random state
        rng = np.random
    bai = QTable(load_bai())
    bai['BAI1'], bai['BAI2'] = galex_to_duet([bai['SURFFUV'].value, bai['SURFNUV'].value], duet=duet)
    oversample = 6
    pixel_size_init = duet.pixel / oversample
    galaxies_within_distance = \
        (bai['DIST'] >= distance - tolerance) & (bai['DIST'] < distance + tolerance)
    if not np.any(galaxies_within_distance):
        # Logger.warn is a deprecated alias of Logger.warning
        log.warning(f"No galaxies in BAI catalogue between "
                    f"{(distance - tolerance).to(u.Mpc).value:.2f} "
                    f"and {(distance + tolerance).to(u.Mpc).value:.2f} Mpc."
                    " Choosing randomly from sample > 150 Mpc")
        galaxies_within_distance = bai['DIST'] > 150 * u.Mpc

    # A single record of the (unit-less) structured-array view of the table
    galaxy = rng.choice(bai[galaxies_within_distance])

    # Rescale the catalogue angular radius from the galaxy's distance to ``distance``
    rad = (galaxy['RAD'] * u.arcsec * galaxy['DIST'] * u.Mpc / distance).to(u.arcsec)
    rad_pix = rad / pixel_size_init

    n = 1 if galaxy['MORPH'] >= 0 else 4
    # Same draw order as before, so a given random state yields the same galaxy
    theta = rng.uniform(0, np.pi)
    ellip = rng.uniform(0.1, 1)
    pos = rng.uniform(0.1, 3)
    mag1, mag2 = galaxy['BAI1'], galaxy['BAI2']
    if average_gal:
        pos = 1
        # Plain values, as stored in the record's float fields
        mag1 = np.median(np.asarray(bai['BAI1']))
        mag2 = np.median(np.asarray(bai['BAI2']))
    d = pos * rad / pixel_size_init
    x_0 = d * np.cos(theta)
    y_0 = d * np.sin(theta)

    gal_params = dict(r_eff=rad_pix.value, n=n, theta=theta, ellip=ellip, x_0=x_0.value, y_0=y_0.value)
    gal_params1 = {'magnitude': mag1, **gal_params}
    gal_params2 = {'magnitude': mag2, **gal_params}
    return gal_params1, gal_params2

In [None]:
def red_chi_sq(f, x, s, dof=None):
    """Reduced chi-square of ``f`` with respect to ``x``, with uncertainties ``s``.

    Parameters
    ----------
    f, x : array-like
        Values to compare. The formula is symmetric in ``f`` and ``x``.
    s : array-like
        1-sigma uncertainties.
    dof : int or None
        Degrees of freedom. Default ``len(f) - 1``.

    Returns
    -------
    float
        ``sum((f - x)**2 / s**2) / dof``, or NaN when ``dof <= 0``. The
        reduced chi-square is undefined there; previously empty input gave
        -0.0, which looks like a perfect fit. Quantity inputs are kept as
        Quantities.
    """
    # asanyarray accepts lists while preserving ndarray subclasses (e.g. Quantity)
    f = np.asanyarray(f)
    x = np.asanyarray(x)
    s = np.asanyarray(s)
    if dof is None:
        dof = len(f) - 1
    if dof <= 0:
        return np.nan
    return np.sum((f - x)**2 / s**2) / dof


def _plot_band_vs_model(ax_flux, ax_res, lc, good, band, model_fun, psf_correction,
                        expo, fmt, color, alpha, size):
    """Plot one band's PSF-corrected fluences (top) and their residuals from the model (bottom).

    Returns the reduced chi-square of the plotted points with respect to ``model_fun``.
    """
    times = lc['time'].value[good]
    # The PSF correction rescales the fluence and its uncertainty alike
    flux = lc[f'fluence_{band}_fit'][good].value / psf_correction
    flux_err = lc[f'fluence_{band}_fiterr'][good].value / psf_correction
    model = model_fun(times)
    redchi = red_chi_sq(flux, model, flux_err)
    ax_flux.errorbar(times / 86400, flux, yerr=flux_err, fmt=fmt,
                     label=f"{band}, {expo}, {redchi:.2f}",
                     alpha=alpha, color=color, markersize=size)
    ax_res.errorbar(times / 86400, flux - model, yerr=flux_err, fmt=fmt,
                    alpha=alpha, color=color, markersize=size)
    return redchi


def plot_realistic_lightcurve(input_lc_file, exposure, label=None, debug=False,  
                              observing_windows=np.array([[0, 30000]]) * u.s, 
                              final_resolution=1200 * u.s, distance=150e6*u.pc,
                              galaxy_type=None, psf_correction = 0.937,
                              show_non_rebinned=False):
    """Simulate a light curve through the imaging pipeline and compare it to its model.

    The top panel shows the model fluences and the PSF-corrected fitted
    fluences. Each legend entry reports the reduced chi-square against the
    model. The bottom panel shows the residuals.

    Parameters
    ----------
    input_lc_file : str
        Model name, passed to ``get_lightcurve`` and the model loaders.
    exposure : Quantity
        Exposure of each simulated image.
    label : str or None
        Figure title and prefix of the output CSV files (default: the model name).
    debug : bool
        Passed to ``lightcurve_through_image``; True dumps all intermediate images.
    observing_windows : Quantity array
        Observing windows passed to ``get_lightcurve``.
    final_resolution : Quantity
        Time resolution of the rebinned light curve.
    distance : Quantity
        Distance of the source.
    galaxy_type : str or None
        "random" or "average" draws host parameters with ``get_random_galaxy``.
        Any other value is passed to ``lightcurve_through_image`` unchanged.
    psf_correction : float
        Factor dividing the fitted fluences and their uncertainties.
    show_non_rebinned : bool
        Also simulate and plot the light curve at the native ``exposure``.

    Side effects: writes ``<label>_model.csv``, ``<label>_modelAB.csv``,
    ``<label>_rebin.csv`` (and ``<label>.csv`` with ``show_non_rebinned``)
    to the working directory, overwriting existing files.
    """
    duet = Telescope()
    # Without a label, the model name is used for the title and file names
    file_label = str(input_lc_file) if label is None else label

    galaxy_par = None
    if galaxy_type in ("random", "average"):
        # NOTE(review): only the D1 parameter set ([0]) is passed on; confirm
        # that lightcurve_through_image does not also need the D2 magnitude.
        galaxy_par = get_random_galaxy(distance, duet, average_gal=(galaxy_type == "average"))[0]
        galaxy_type = "custom"

    lightcurve_init = get_lightcurve(input_lc_file, exposure=exposure,
                                     observing_windows=observing_windows, distance=distance)
    lightcurve_rebin = lightcurve_through_image(lightcurve_init, exposure=exposure,
                                                final_resolution=final_resolution,
                                                debug=debug, gal_params=galaxy_par, gal_type=galaxy_type)
    lightcurves = [lightcurve_rebin]
    exposures = [final_resolution]
    filenames = [file_label + '_rebin.csv']
    if show_non_rebinned:
        lightcurve = lightcurve_through_image(lightcurve_init, exposure=exposure, debug=debug,
                                              gal_params=galaxy_par, gal_type=galaxy_type)
        # Native-resolution curve first, so the rebinned points are drawn on top
        lightcurves.insert(0, lightcurve)
        exposures.insert(0, exposure)
        filenames.insert(0, file_label + '.csv')

    flux_unit = u.ph / u.cm**2 / u.s
    model_lc_table_fl = QTable(load_model_fluence(input_lc_file, dist=distance))
    model_lc_table_AB = QTable(load_model_ABmag(input_lc_file, dist=distance))
    model_lc_table_fl['fluence_D1'] = model_lc_table_fl['fluence_D1'].to(flux_unit)
    model_lc_table_fl['fluence_D2'] = model_lc_table_fl['fluence_D2'].to(flux_unit)

    model_fun_D1 = interp1d(model_lc_table_fl['time'], model_lc_table_fl['fluence_D1'],
                            bounds_error=False, fill_value='extrapolate')
    model_fun_D2 = interp1d(model_lc_table_fl['time'], model_lc_table_fl['fluence_D2'],
                            bounds_error=False, fill_value='extrapolate')

    plt.figure(figsize=(15, 10))
    gs = plt.GridSpec(2, 1, hspace=0)
    ax0 = plt.subplot(gs[0])
    ax1 = plt.subplot(gs[1], sharex=ax0)
    plt.suptitle(file_label)

    model_days = model_lc_table_fl['time'].value / 86400
    ax0.plot(model_days, model_lc_table_fl['fluence_D1'].value, label="init D1", color='r')
    ax0.plot(model_days, model_lc_table_fl['fluence_D2'].value, label="init D2", color='b')
    ax1.axhline(0, color='k', ls='--')

    # overwrite=True: re-running the cell must not fail on files from a previous run
    model_lc_table_fl.write(file_label + '_model.csv', overwrite=True)
    model_lc_table_AB.write(file_label + '_modelAB.csv', overwrite=True)

    for lc, expo, filename in zip(lightcurves, exposures, filenames):
        # Points at the native exposure are drawn smaller and semi-transparent
        alpha, size = (0.3, 5) if expo == exposure else (1, 10)
        # Keep bins with a positive fit, a positive error and S/N > 1 in both bands
        good1 = (lc['fluence_D1_fit'] > 0) & (lc['fluence_D1_fiterr'] < lc['fluence_D1_fit']) & (lc['fluence_D1_fiterr'] > 0)
        good2 = (lc['fluence_D2_fit'] > 0) & (lc['fluence_D2_fiterr'] < lc['fluence_D2_fit']) & (lc['fluence_D2_fiterr'] > 0)
        good = good1 & good2
        print(f"{filename}: {np.count_nonzero(good)} of {len(lc)} bins kept (resolution {expo})")

        _plot_band_vs_model(ax0, ax1, lc, good, 'D1', model_fun_D1, psf_correction,
                            expo, 'o', 'r', alpha, size)
        _plot_band_vs_model(ax0, ax1, lc, good, 'D2', model_fun_D2, psf_correction,
                            expo, 's', 'b', alpha, size)

        lc[good].write(filename, overwrite=True)

    ax0.set_ylabel("Fluence (ph / cm^2 / s)")
    ax1.set_ylabel("Residual (ph / cm^2 / s)")
    # ax1 is the bottom panel (shared x axis, no vertical gap): label time there
    ax1.set_xlabel("Time (d)")
    ax0.tick_params(labelbottom=False)
    ax0.set_xlim([lightcurve_rebin['time'].value.min() / 86400 - 0.1,
                  lightcurve_rebin['time'].value.max() / 86400 + 0.1])

    # Residual range: +-10 times the median (PSF-corrected) positive D1 uncertainty
    d1_err = lightcurve_rebin['fluence_D1_fiterr'].value
    d1_err = d1_err[d1_err > 0]
    if d1_err.size > 0:
        ref_err = np.median(d1_err) / psf_correction
        ax1.set_ylim([-ref_err * 10, ref_err * 10])

    ax0.legend()


In [None]:
# Average host:
# 

In [None]:
# GW models without host galaxy, 100-500 Mpc, one day of observations
for dist_mpc in range(100, 600, 100):
    for model in ["kilonova_0.04.dat", "shock_5e10.dat"]:
        plot_realistic_lightcurve(
            model,
            300 * u.s,
            observing_windows=np.array([[0, 86400]]) * u.s,
            final_resolution=2400 * u.s,
            distance=dist_mpc * u.Mpc,
            label=f"{model}, {dist_mpc}Mpc",
            debug=True,
            show_non_rebinned=True,
        )

In [None]:
# GW models with an average host galaxy, 50-150 Mpc
for dist_mpc in range(50, 200, 50):
    for model in ["kilonova_0.04.dat", "shock_5e10.dat"]:
        plot_realistic_lightcurve(
            model,
            300 * u.s,
            observing_windows=np.array([[0, 86400]]) * u.s,
            final_resolution=2400 * u.s,
            distance=dist_mpc * u.Mpc,
            label=f"{model}, {dist_mpc}Mpc - galaxy",
            debug=True,
            show_non_rebinned=True,
            galaxy_type='average',
        )

In [None]:
# shock_5e10 at 70 Mpc, observing window 1800-30000 s
plot_realistic_lightcurve(
    "shock_5e10.dat",
    300 * u.s,
    observing_windows=np.array([[1800, 30000]]) * u.s,
    final_resolution=2400 * u.s,
    distance=70e6 * u.pc,
    label="shock_5e10, 70Mpc",
    debug=True,
)

In [None]:
# Same as above, with a random host galaxy
plot_realistic_lightcurve(
    "shock_5e10.dat",
    300 * u.s,
    observing_windows=np.array([[1800, 30000]]) * u.s,
    final_resolution=2400 * u.s,
    distance=70e6 * u.pc,
    label="shock_5e10, 70Mpc, with galaxy",
    debug=True,
    galaxy_type="random",
)

In [None]:
# kilonova_0.04 at 150 Mpc, rebinned to 4800 s
plot_realistic_lightcurve(
    "kilonova_0.04.dat",
    300 * u.s,
    observing_windows=np.array([[1800, 900000]]) * u.s,
    final_resolution=4800 * u.s,
    distance=150e6 * u.pc,
    label="blukn_04, 150Mpc",
    debug=True,
)

In [None]:
# rsg400 at 200 Mpc (rebinned light curve only, the default)
plot_realistic_lightcurve(
    "rsg400",
    300 * u.s,
    observing_windows=np.array([[0, 500000]]) * u.s,
    final_resolution=1200 * u.s,
    distance=200 * u.Mpc,
    label="RSG 400 Ro, 200Mpc",
    debug=True,
)

In [None]:
# ysg400 at 200 Mpc (rebinned light curve only, the default)
plot_realistic_lightcurve(
    "ysg400",
    300 * u.s,
    observing_windows=np.array([[0, 500000]]) * u.s,
    final_resolution=1200 * u.s,
    distance=200 * u.Mpc,
    label="YSG 400 Ro, 200Mpc",
    debug=True,
)

In [None]:
# Supernova models at 100, 150 and 200 Mpc
for dist_mpc in [100, 150, 200]:
    for model in ['bsg80', 'rsg400', 'ysg400']:
        star = model.upper().replace('SG', 'SG ')  # e.g. 'bsg80' -> 'BSG 80'
        plot_realistic_lightcurve(
            model,
            300 * u.s,
            observing_windows=np.array([[0, 500000]]) * u.s,
            final_resolution=2400 * u.s,
            distance=dist_mpc * u.Mpc,
            label=f"{star} Ro, {dist_mpc}Mpc",
            debug=True,
            show_non_rebinned=True,
        )

In [None]:
# Supernova models at 300, 400 and 500 Mpc
for dist_mpc in [300, 400, 500]:
    for model in ['bsg80', 'rsg400', 'ysg400']:
        star = model.upper().replace('SG', 'SG ')  # e.g. 'bsg80' -> 'BSG 80'
        plot_realistic_lightcurve(
            model,
            300 * u.s,
            observing_windows=np.array([[0, 500000]]) * u.s,
            final_resolution=2400 * u.s,
            distance=dist_mpc * u.Mpc,
            label=f"{star} Ro, {dist_mpc}Mpc",
            debug=True,
            show_non_rebinned=True,
        )

In [None]:
# Supernova models at 600, 700 and 800 Mpc
for dist_mpc in [600, 700, 800]:
    for model in ['bsg80', 'rsg400', 'ysg400']:
        star = model.upper().replace('SG', 'SG ')  # e.g. 'bsg80' -> 'BSG 80'
        plot_realistic_lightcurve(
            model,
            300 * u.s,
            observing_windows=np.array([[0, 500000]]) * u.s,
            final_resolution=2400 * u.s,
            distance=dist_mpc * u.Mpc,
            label=f"{star} Ro, {dist_mpc}Mpc",
            debug=True,
            show_non_rebinned=True,
        )

In [None]:
# rsg600 at 370 Mpc, rebinned to 1500 s
plot_realistic_lightcurve(
    "rsg600",
    300 * u.s,
    observing_windows=np.array([[0, 500000]]) * u.s,
    final_resolution=1500 * u.s,
    distance=370e6 * u.pc,
    label="RSG 600 Ro, 370Mpc",
    debug=True,
)

In [None]:
## In case someone wants to take a look at the debug images...
## set debug=True in plot_realistic_lightcurve and look at them

def plot_debug_images(directory):
    """Show the debug images stored in ``directory``.

    First plots the ``imgD1`` image of every pickle (``*.p``) in the
    directory, in sorted file-name order. Then, for every row of
    ``lightcurve.hdf5``, plots its ``imgs_D1_bkgsub`` image minus the
    reference image. The reference is the pickle whose file name contains
    'ref'; if there are several, the last one in sorted order is used.

    Parameters
    ----------
    directory : str
        Debug directory (presumably written by a run with ``debug=True``).

    Raises
    ------
    ValueError
        If no reference pickle is found in ``directory``.
    """
    from statsmodels.robust import mad
    # NOTE: pickle.load can execute arbitrary code; only use on locally produced debug files.
    img_pickles = sorted(glob.glob(os.path.join(directory, '*.p')))
    ref_img = None
    for img_pickle in img_pickles:
        with open(img_pickle, 'rb') as fobj:
            img = pickle.load(fobj)
        image1 = img['imgD1']

        plt.figure(figsize=(10, 10))
        plt.title(img_pickle)
        plt.imshow(image1.value)
        plt.colorbar()
        # Match on the file name only, not on the directory part of the path
        if 'ref' in os.path.basename(img_pickle):
            ref_img = image1.value

    if ref_img is None:
        raise ValueError(f"No reference image (pickle with 'ref' in its name) found in {directory!r}")

    img_lc = Table.read(os.path.join(directory, 'lightcurve.hdf5'))

    for row in img_lc:
        diff = row['imgs_D1_bkgsub'] - ref_img
        plt.figure(figsize=(3, 3))
        plt.imshow(diff - np.median(diff), vmin=0, vmax=5 * mad(diff.flatten()))
        plt.title(row['time'])


# Debug directories only exist after a run with debug=True: skip if missing,
# so that "Restart & Run All" does not fail here.
DEBUG_DIR = 'debug_imgs_33386538'
if os.path.isdir(DEBUG_DIR):
    plot_debug_images(DEBUG_DIR)
else:
    print(f"Debug directory {DEBUG_DIR!r} not found: set it to an existing debug_imgs_* directory.")



In [None]:
def fit_lightcurve(lightcurve, label='lightcurve fit', solutions=None,
                   debug=False, additional_info_for_table=None,
                   fit_model_files=None, correct_model="", distance=None, 
                   psf_correction = 0.937):
    """Fit a simulated light curve with a set of model light curves.

    Each model, loaded at ``distance``, is rescaled by one free multiplicative
    constant. The constant is fitted separately to the D1 fluences, the D2
    fluences and the D2/D1 ratio. For each, the reduced chi-square after the
    fit and with the constant fixed to 1 (``*_nofit``) are stored. The result
    is one row per model, appended to ``solutions``.

    Parameters
    ----------
    lightcurve : Table
        Needs the columns ``time``, ``fluence_D{1,2}_fit``,
        ``fluence_D{1,2}_fiterr`` and ``snr_D{1,2}``.
    label : str
        Title of the debug figure.
    solutions : Table or None
        Table to append to. If None, a new QTable is created.
    debug : bool
        If True, plot data and fitted models.
    additional_info_for_table : dict or None
        Extra constant values added to every row. A new table gets one column
        per key: string 'U30' if the key contains 'model', else the value's dtype.
    fit_model_files : list or None
        Models to fit. Default: ``Simulations().emgw_simulations``.
    correct_model : str
        Model used in the simulation, highlighted in the debug plot ("" = none).
    distance : Quantity or None
        Distance at which the models are loaded. Default 100 Mpc.
    psf_correction : float
        Factor dividing the fitted fluences and their uncertainties.

    Returns
    -------
    solutions : Table
        With rows added. If fewer than 5 usable points remain, the rows are
        placeholders with chi-squares of -1 and ``ngood = 0``.
    """
    from astroduet.models import Simulations
    from scipy.optimize import curve_fit
    from scipy.interpolate import interp1d
    from astropy.visualization import quantity_support
    quantity_support()
    if additional_info_for_table is None:
        additional_info_for_table = {}
    if distance is None:
        distance = 100 * u.Mpc
    if solutions is None:
        add_names = list(additional_info_for_table)
        add_dtypes = [np.asarray(v).dtype if 'model' not in k else 'U30'
                      for k, v in additional_info_for_table.items()]
        names = ['fit_model', 'D1', 'D2', 'ratio', 'D1_chisq', 'D2_chisq', 'ratio_chisq',
                 'D1_chisq_nofit', 'D2_chisq_nofit', 'ratio_chisq_nofit', 'ngood'] + add_names
        dtypes = ['U30'] + [float] * 9 + [int] + add_dtypes
        solutions = QTable(names=names, dtype=dtypes)

    lc_files = Simulations().emgw_simulations if fit_model_files is None else fit_model_files

    # Keep the points with positive fluence and S/N > 5 in at least one band
    good = (((lightcurve['fluence_D1_fit'] / psf_correction > 0) & (lightcurve['snr_D1'] > 5)) |
            ((lightcurve['fluence_D2_fit'] / psf_correction > 0) & (lightcurve['snr_D2'] > 5)))
    lightcurve = lightcurve[good]
    if len(lightcurve) < 5:
        print("Lightcurve is invalid")
        empty_row_dict = {'D1': 0, 'D2': 0, 'ratio': 0,
                          'D1_chisq': -1, 'D2_chisq': -1, 'ratio_chisq': -1,
                          'D1_chisq_nofit': -1, 'D2_chisq_nofit': -1, 'ratio_chisq_nofit': -1,
                          'ngood': 0}
        for lc_file in lc_files:
            solutions.add_row({'fit_model': lc_file, **empty_row_dict, **additional_info_for_table})
        return solutions

    times = lightcurve['time']
    fluence_D1 = lightcurve['fluence_D1_fit'] / psf_correction
    fluence_D2 = lightcurve['fluence_D2_fit'] / psf_correction
    fluence_D1_err = lightcurve['fluence_D1_fiterr'] / psf_correction
    fluence_D2_err = lightcurve['fluence_D2_fiterr'] / psf_correction

    ratio = fluence_D2 / fluence_D1
    # NOTE(review): relative errors are summed linearly (a conservative bound)
    # rather than in quadrature; confirm this is intended.
    ratio_err = ratio * (fluence_D1_err / fluence_D1 +
                         fluence_D2_err / fluence_D2)

    if debug:
        plt.figure(figsize=(15, 15))
        plt.suptitle(label)
        gs = plt.GridSpec(3, 1)
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1], sharex=ax1)
        axr = plt.subplot(gs[2], sharex=ax1)
        ax1.errorbar(times, fluence_D1, yerr=fluence_D1_err, fmt='o', markersize=5, label='data')
        ax2.errorbar(times, fluence_D2, yerr=fluence_D2_err, fmt='o', markersize=5, label='data')
        axr.errorbar(times, ratio, yerr=ratio_err, fmt='o', markersize=5, label='data')

    # Model fluences are compared in the same unit used for them in plot_realistic_lightcurve
    flux_unit = u.ph / u.cm**2 / u.s
    for lc_file in lc_files:
        model_lc_table_fl = QTable(load_model_fluence(lc_file, dist=distance))
        model_times = model_lc_table_fl['time'].to(u.s).value
        # Zero outside the model time range
        interpolated_lc_1 = interp1d(model_times,
                                     model_lc_table_fl['fluence_D1'].to(flux_unit).value,
                                     fill_value=0, bounds_error=False)
        interpolated_lc_2 = interp1d(model_times,
                                     model_lc_table_fl['fluence_D2'].to(flux_unit).value,
                                     fill_value=0, bounds_error=False)

        def interpolated_lc_ratio(time):
            return interpolated_lc_2(time) / interpolated_lc_1(time)

        def constant_fit_fun_1(x, a):
            return a * interpolated_lc_1(x)

        def constant_fit_fun_2(x, a):
            return a * interpolated_lc_2(x)

        def constant_fit_fun_ratio(x, a):
            return a * interpolated_lc_ratio(x)

        # Model at its nominal normalization (a = 1)
        d1_chisq_nofit = red_chi_sq(constant_fit_fun_1(times, 1),
                                    fluence_D1.value, fluence_D1_err.value, dof=len(fluence_D1) - 1)
        d2_chisq_nofit = red_chi_sq(constant_fit_fun_2(times, 1),
                                    fluence_D2.value, fluence_D2_err.value, dof=len(fluence_D2) - 1)
        ratio_chisq_nofit = red_chi_sq(constant_fit_fun_ratio(times, 1),
                                       ratio.value, ratio_err.value, dof=len(ratio) - 1)

        par1, pcov1 = curve_fit(constant_fit_fun_1,
                                times, fluence_D1,
                                sigma=fluence_D1_err, p0=[1])
        par2, pcov2 = curve_fit(constant_fit_fun_2,
                                times, fluence_D2,
                                sigma=fluence_D2_err, p0=[1])
        parr, pcovr = curve_fit(constant_fit_fun_ratio,
                                times, ratio, sigma=ratio_err, p0=[1])

        d1_chisq = red_chi_sq(constant_fit_fun_1(times, *par1),
                              fluence_D1.value, fluence_D1_err.value, dof=len(fluence_D1) - 1)
        d2_chisq = red_chi_sq(constant_fit_fun_2(times, *par2),
                              fluence_D2.value, fluence_D2_err.value, dof=len(fluence_D2) - 1)
        ratio_chisq = red_chi_sq(constant_fit_fun_ratio(times, *parr),
                                 ratio.value, ratio_err.value, dof=len(ratio) - 1)
        new_row = {'fit_model': lc_file, 'D1': par1, 'D2': par2, 'ratio': parr,
                   'D1_chisq': d1_chisq, 'D2_chisq': d2_chisq, 'ratio_chisq': ratio_chisq,
                   'D1_chisq_nofit': d1_chisq_nofit, 'D2_chisq_nofit': d2_chisq_nofit,
                   'ratio_chisq_nofit': ratio_chisq_nofit,
                   'ngood': len(fluence_D1)}
        solutions.add_row({**new_row, **additional_info_for_table})

        if debug:
            fine_times = np.linspace(times[0], times[-1], 1000)
            is_correct = lc_file.replace('.dat', '') == correct_model.replace('.dat', '')
            alpha, lw = (1, 1) if is_correct else (0.3, 0.5)

            ax1.plot(fine_times,
                     constant_fit_fun_1(fine_times, *par1), color='k', alpha=alpha, lw=lw, markersize=0)
            ax2.plot(fine_times,
                     constant_fit_fun_2(fine_times, *par2), color='k', alpha=alpha, lw=lw, markersize=0)
            axr.plot(fine_times,
                     constant_fit_fun_ratio(fine_times, *parr), color='k', alpha=alpha, lw=lw, markersize=0)
            # "" is a substring of every name: only highlight when a model was given
            if correct_model and correct_model in lc_file:
                ax1.plot(fine_times,
                         constant_fit_fun_1(fine_times, 1), label=correct_model, color='b', lw=1, markersize=0)
                ax2.plot(fine_times,
                         constant_fit_fun_2(fine_times, 1), label=correct_model, color='b', lw=1, markersize=0)
                axr.plot(fine_times,
                         constant_fit_fun_ratio(fine_times, 1), label=correct_model, color='b', lw=1, markersize=0)

    if debug:
        axr.set_ylabel('Flux ratio')
        ax1.legend()
        print("Debug done")
        plt.show()

    return solutions

In [None]:
def simulate_and_fit(model, observing_windows = np.array([[0, 86400 * 5]]) * u.s, distance = 200 * u.Mpc,
                     exposure=300 * u.s, final_resolution=1200 * u.s, psf_correction=0.937):
    """Simulate ``model`` at ``distance`` and fit it with all the supernova models.

    The same ``exposure`` is used both for sampling the light curve and for
    the image simulation, as in the other simulations of this notebook. The
    simulated light curve is rebinned to ``final_resolution`` and fitted with
    every BSG, RSG and YSG simulation. The debug plot highlights ``model``.

    Returns
    -------
    solutions : Table
        One row per fitted model (see ``fit_lightcurve``).
    """
    s = Simulations()

    lightcurve_init = get_lightcurve(model, exposure=exposure,
                                     observing_windows=observing_windows,
                                     distance=distance)
    lightcurve_rebin = lightcurve_through_image(lightcurve_init, exposure=exposure,
                                                final_resolution=final_resolution,
                                                silent=True)

    return fit_lightcurve(lightcurve_rebin, label=f'{model}, {distance.to(u.Mpc).value} Mpc',
                          debug=True,
                          fit_model_files=s.sne_bsg_simulations + s.sne_rsg_simulations + s.sne_ysg_simulations,
                          correct_model=model, distance=distance, psf_correction=psf_correction)


simulate_and_fit('bsg80');

In [None]:
# rsg600 simulated at 600 Mpc
simulate_and_fit(model='rsg600', distance=600 * u.Mpc)

In [None]:
# rsg600 simulated at 500 Mpc
simulate_and_fit(model='rsg600', distance=500 * u.Mpc)

In [None]:
# ysg400 at the default distance
simulate_and_fit(model='ysg400')

In [None]:
# Simulate kilonova_0.04 at 100 Mpc and fit it with the default (GW) model set.
# The simulated model and distance are passed to fit_lightcurve explicitly, so
# the true model is highlighted and the models are loaded at the right distance.
KN_MODEL = 'kilonova_0.04.dat'
KN_DISTANCE = 100 * u.Mpc
lightcurve_init = \
    get_lightcurve(KN_MODEL, exposure=300*u.s,
                   observing_windows=np.array([[1800, 40000]]) * u.s,
                   distance=KN_DISTANCE)
lightcurve_rebin = lightcurve_through_image(lightcurve_init, exposure=300*u.s,
                                            final_resolution=6000*u.s,
                                            silent=True)

solutions = fit_lightcurve(lightcurve_rebin, label='kilonova_0.04.dat - 100 Mpc',
                           debug=True, correct_model=KN_MODEL, distance=KN_DISTANCE)


In [None]:
def get_rebinned_lightcurve_fit(input_lc_file, exposure, label=None, debug=False,  
                            observing_windows=np.array([[1800, 30000]]) * u.s, 
                            final_resolution=1200 * u.s, distance=100e6*u.pc,
                            ntrial=100, outfile=None):
    """Repeatedly simulate a light curve and fit it with the GW models.

    Each of the ``ntrial`` trials runs ``get_lightcurve`` and
    ``lightcurve_through_image`` (stdout suppressed), then ``fit_lightcurve``.
    If ``outfile`` exists, its rows are loaded first and new results are
    appended to them.

    Parameters
    ----------
    input_lc_file : str
        Model to simulate.
    exposure, final_resolution : Quantity
        Exposure of each image and time resolution of the rebinned light curve.
    label : str or None
        Label passed to ``fit_lightcurve`` (default: ``input_lc_file``).
    debug : bool
        Passed to ``lightcurve_through_image``.
    observing_windows : Quantity array
        Passed to ``get_lightcurve``.
    distance : Quantity
        Distance of the simulated source. Fit models are loaded at the same
        distance.
    ntrial : int
        Number of simulated light curves.
    outfile : str or None
        File to read previous results from and to write all results to.

    Returns
    -------
    solutions : Table or None
        All accumulated results. None if nothing was produced and there was no
        ``outfile`` to start from.
    """
    from tqdm.auto import tqdm
    from astroduet.utils import suppress_stdout
    if outfile is not None and os.path.exists(outfile):
        solutions = Table.read(outfile)
    else:
        solutions = None
    fit_label = input_lc_file if label is None else label

    for _ in tqdm(range(ntrial), desc=str(fit_label)):
        try:
            with suppress_stdout():
                lightcurve_init = \
                    get_lightcurve(input_lc_file, exposure=exposure, observing_windows=observing_windows,
                                   distance=distance)

                lightcurve_rebin = lightcurve_through_image(lightcurve_init, exposure=exposure,
                                                            final_resolution=final_resolution,
                                                            debug=debug, silent=True)

            # Load the models at the simulated distance, so that the *_nofit
            # chi-squares compare with the correctly scaled models
            solutions = fit_lightcurve(lightcurve_rebin, label=fit_label, solutions=solutions,
                                       debug=False, distance=distance)
        except Exception as e:
            # Deliberate best effort: stop the trials but keep what was obtained so far
            print("An exception occurred. Intermediate results are saved in the solution Table")
            print(e)
            break

    # Nothing to write if the very first trial failed and there was no previous file
    if outfile is not None and solutions is not None:
        solutions.write(outfile, overwrite=True)

    return solutions

**Uncomment the cell below to produce the data.**

In [None]:
# solutions_sh510_30 = get_rebinned_lightcurve_fit('shock_5e10.dat', exposure=300 * u.s,
#                         observing_windows=np.array([[1800, 30000]]) * u.s,
#                         distance=200 * u.Mpc, ntrial=100)

# sns.pairplot(solutions_sh510_30.to_pandas(), hue='model', diag_kind="kde", vars='D1_chisq,D2_chisq,ratio_chisq'.split(','))
# solutions_sh510_30.write('solutions_sh510.csv', overwrite=True)

Figure description: The fits with low $\chi^2$ are systematically those using the correct model (in this case, a shock-type GRB). We generated 30 light curves from the model `shock_5e10` at 200 Mpc, starting 30 minutes after the event and including all instrumental and zodiacal noise sources, and fitted each of them with all six GW models. The best fit to the D1 and D2 light curves (as measured by low values of $\chi^2$) is systematically the one corresponding to the correct model.

In [None]:
# Accumulate fit results in batches of `ntrial` light curves. Each batch is saved
# to its CSV file, so an interruption keeps the previous batches.
# The previous `while 1:` never terminated, so the cells below never ran.
# N_BATCHES * ntrial = 30 light curves per model, as in the figure descriptions.
# Results are appended to existing CSV files: delete them to start from scratch.
ntrial = 10
N_BATCHES = 3
for _batch in range(N_BATCHES):
    solutions_sh510 = get_rebinned_lightcurve_fit('shock_5e10.dat', exposure=300 * u.s,
                        observing_windows=np.array([[1800, 30000]]) * u.s,
                        distance=200 * u.Mpc, ntrial=ntrial, outfile='solutions_sh510.csv')

    solutions_k04 = get_rebinned_lightcurve_fit('kilonova_0.04.dat', exposure=300 * u.s,
                            observing_windows=np.array([[1800, 40000]]) * u.s,
                            final_resolution=6000 * u.s,
                            distance=130*u.Mpc, ntrial=ntrial, outfile='solutions_k04.csv')

In [None]:
# Colour by fitted model. fit_lightcurve writes it to 'fit_model'; a 'model'
# column is used instead when present (e.g. in CSV files from earlier runs).
for sol_table in (solutions_sh510, solutions_k04):
    hue_col = 'model' if 'model' in sol_table.colnames else 'fit_model'
    sns.pairplot(sol_table.to_pandas(), hue=hue_col, diag_kind="kde",
                 vars=['D1_chisq', 'D2_chisq', 'ratio_chisq'])


Figure description: Same as the previous figure, but this time we simulated 30 light curves from the model `kilonova_0.04` at 150 Mpc, starting 30 minutes after the event and including all instrumental and zodiacal noise sources, and fitted each of them with all six GW models. In this case, the best fit to the D1 and D2 light curves (as measured by low values of $\chi^2$) separates kilonova models from shock GRB models, but does not clearly distinguish between different kilonova models.

In [None]:
def _linear_bins(values, nbins):
    """Linear bin edges spanning ``values``; ``nbins`` itself if all values are equal."""
    vmin, vmax = np.min(values), np.max(values)
    if vmax > vmin:
        return np.linspace(vmin, vmax, nbins)
    # Degenerate range: let matplotlib choose the edges around the single value
    return nbins


def plot_distributions(solutions, correct, label=None, nbins=10):
    """Plot the reduced chi-square distributions of the fits, per fitted model.

    Layout, top to bottom: D1 chi-squares of ``correct``, D1 of the other
    models, D2 of ``correct``, D2 of the other models. When ``correct`` is
    given, its 99th percentile is drawn as a threshold. Each other model's
    legend reports the fraction of its fits above that threshold ("rej.").
    Rows with chi-square >= 1e32 or fewer than 4 good points are discarded.

    Parameters
    ----------
    solutions : Table
        Output of ``fit_lightcurve``. The fitted model is read from 'model' if
        that column exists, else from 'fit_model'.
    correct : str or None
        The model used in the simulation. None skips its panels and the thresholds.
    label : str or None
        Figure title; also the legend label of the ``correct`` histograms.
    nbins : int
        Number of histogram bin edges.

    Raises
    ------
    ValueError
        If ``correct`` is given but none of the retained rows uses it.
    """
    model_col = 'model' if 'model' in solutions.colnames else 'fit_model'
    plt.figure(figsize=(15, 15))
    plt.suptitle(label)
    gs = plt.GridSpec(4, 1, height_ratios=(4, 3, 4, 3), hspace=0)
    ax11 = plt.subplot(gs[0])
    ax12 = plt.subplot(gs[1], sharex=ax11)
    ax21 = plt.subplot(gs[2], sharex=ax11)
    ax22 = plt.subplot(gs[3], sharex=ax11)
    good = (solutions['D1_chisq'] < 1e32) & (solutions['D2_chisq'] < 1e32) & (solutions['ngood'] >= 4)
    solutions = solutions[good]

    per_99_1 = per_99_2 = None
    if correct is not None:
        sol = solutions[solutions[model_col] == correct]
        if len(sol) == 0:
            raise ValueError(f"No valid fits with {model_col} == {correct!r}")
        per_99_1 = np.percentile(sol['D1_chisq'], 99)
        per_99_2 = np.percentile(sol['D2_chisq'], 99)
        ax11.hist(sol['D1_chisq'], label=label, alpha=1, density=True,
                  bins=_linear_bins(sol['D1_chisq'], nbins))
        ax21.hist(sol['D2_chisq'], label=label, alpha=1, density=True,
                  bins=_linear_bins(sol['D2_chisq'], nbins))
        ax11.axvline(per_99_1, ls='--', lw=3, color='b')
        ax21.axvline(per_99_2, ls='--', lw=3, color='b')

    # Separate loop name, so that ``label`` (the figure label) is not shadowed
    for fit_model in sorted(set(solutions[model_col])):
        if fit_model == correct:
            continue
        sol = solutions[solutions[model_col] == fit_model]
        d1_label = d2_label = str(fit_model)
        if per_99_1 is not None:
            # Fraction of fits worse than the 99th percentile of the correct model
            rej1 = np.count_nonzero(sol['D1_chisq'] > per_99_1) / len(sol)
            rej2 = np.count_nonzero(sol['D2_chisq'] > per_99_2) / len(sol)
            d1_label += f' rej. {rej1*100:.0f}%'
            d2_label += f' rej. {rej2*100:.0f}%'
        ax12.hist(sol['D1_chisq'], label=d1_label, alpha=0.4, density=True,
                  bins=_linear_bins(sol['D1_chisq'], nbins))
        ax22.hist(sol['D2_chisq'], label=d2_label, alpha=0.4, density=True,
                  bins=_linear_bins(sol['D2_chisq'], nbins))

    if per_99_1 is not None:
        ax12.axvline(per_99_1, ls='--', lw=3, color='b', label='99% percentile')
        ax22.axvline(per_99_2, ls='--', lw=3, color='b', label='99% percentile')

    for ax in [ax11, ax12, ax21]:
        ax.xaxis.set_visible(False)

    for ax in [ax11, ax12, ax21, ax22]:
        ax.legend()
        ax.semilogx()
        ax.axvline(1, ls='--', lw=3, color='k', alpha=0.5)
        ax.set_ylabel('Hist. Density')
    ax22.set_xlabel(r'$\chi^2_{\rm red}$ (~1 indicates good fit)')
    ax11.set_xlim([1, None])

def plot_distributions_double(solutions, correct, label=None, nbins=10):
    """Plot reduced-chi^2 histograms of model fits in a two-column layout.

    Both columns show the same four stacked panels:
      row 0: D1 chi^2 of fits with the `correct` model
      row 1: D1 chi^2 of fits with every other model
      row 2: D2 chi^2 of fits with the `correct` model
      row 3: D2 chi^2 of fits with every other model
    The left (narrow) column keeps linear x-axes zoomed on the low-chi^2
    region; the right (wide) column uses log x-axes and carries the legends.

    Parameters
    ----------
    solutions : table-like
        Needs 'model', 'D1_chisq', 'D2_chisq' and 'ngood' columns. Rows with
        either chi^2 >= 1e32 or with 'ngood' < 4 are dropped before plotting.
    correct : str or None
        Model name used as the reference (presumably the model the simulated
        data were drawn from -- confirm). The 99th percentile of its chi^2 is
        the rejection threshold quoted for every other model. If None, no
        threshold lines or rejection fractions are drawn.
    label : str, optional
        Figure title; also the legend label of the `correct` histograms.
    nbins : int
        Number of histogram bin edges, spanning each sample's min..max.
    """
    plt.figure(figsize=(15, 15))
    plt.suptitle(label)
    gs = plt.GridSpec(4, 2, height_ratios=(4, 3, 4, 3), width_ratios=(1, 3),
                      wspace=0.1, hspace=0)

    # Drop failed/degenerate fits once; this does not depend on the column.
    valid = ((solutions['D1_chisq'] < 1e32) & (solutions['D2_chisq'] < 1e32)
             & (solutions['ngood'] >= 4))
    solutions = solutions[valid]

    # Rejection thresholds from the reference model (None when there is none).
    per_99_1 = per_99_2 = None
    correct_sol = None
    if correct is not None:
        correct_sol = solutions[solutions['model'] == correct]
        per_99_1 = np.percentile(correct_sol['D1_chisq'], 99)
        per_99_2 = np.percentile(correct_sol['D2_chisq'], 99)

    # Competing models in a stable order. A distinct loop name is used below so
    # the `label` argument is not overwritten between columns.
    other_models = [m for m in sorted(set(solutions['model'])) if m != correct]

    def _hist(ax, values, **kwargs):
        # Density histogram with `nbins` edges covering the sample's range.
        ax.hist(values, density=True,
                bins=np.linspace(np.min(values), np.max(values), nbins),
                **kwargs)

    for i in [0, 1]:
        ax11 = plt.subplot(gs[0, i])
        ax12 = plt.subplot(gs[1, i], sharex=ax11)
        ax21 = plt.subplot(gs[2, i], sharex=ax11)
        ax22 = plt.subplot(gs[3, i], sharex=ax11)

        if correct is not None:
            _hist(ax11, correct_sol['D1_chisq'], label=label, alpha=1)
            _hist(ax21, correct_sol['D2_chisq'], label=label, alpha=1)
            ax11.axvline(per_99_1, ls='--', lw=3, color='b')
            ax21.axvline(per_99_2, ls='--', lw=3, color='b')

        for model in other_models:
            sol = solutions[solutions['model'] == model]
            for ax, col, threshold in ((ax12, 'D1_chisq', per_99_1),
                                       (ax22, 'D2_chisq', per_99_2)):
                model_label = model
                if threshold is not None:
                    # Fraction of this model's fits worse than the reference 99% level.
                    rej = np.count_nonzero(sol[col] > threshold) / len(sol[col])
                    model_label = model + f' rej. {rej*100:.0f}%'
                _hist(ax, sol[col], label=model_label, alpha=0.4)

        if correct is not None:
            ax12.axvline(per_99_1, ls='--', lw=3, color='b', label='99% percentile')
            ax22.axvline(per_99_2, ls='--', lw=3, color='b', label='99% percentile')

        # Panels share x; only the bottom one shows tick labels.
        for ax in [ax11, ax12, ax21]:
            ax.xaxis.set_visible(False)

        for ax in [ax11, ax12, ax21, ax22]:
            if i == 1:
                ax.legend()
                ax.semilogx()
            ax.axvline(1, ls='--', lw=3, color='k', alpha=0.5)
            if i == 0:
                ax.set_ylabel('Hist. Density')
        ax22.set_xlabel(r'$\chi^2_{\rm red}$ (~1 indicates good fit)')
        if i == 0:
            # Linear column: zoom on the region around the rejection thresholds.
            if per_99_1 is None:
                upper = 5
            else:
                upper = max(per_99_1 + 1, per_99_2 + 1, 5)
            ax11.set_xlim([0, upper])
        else:
            ax11.set_xlim([1, None])
    return

# Fit results for the 'shock_5e10.dat' model (label quotes 200 Mpc).
# Presumably each trial contributes one row per candidate model, so row counts
# divided by the number of models give trial counts. The number of models is
# read from the table rather than hard-coded as 6; max(..., 1) avoids a
# ZeroDivisionError on an empty file.
solutions_sh510_30 = Table.read('solutions_sh510.csv')
n_models = max(len(set(solutions_sh510_30['model'])), 1)
ntrials = len(solutions_sh510_30) // n_models
# Rows with non-positive D1 chi^2 are treated as invalid fits and discarded.
good = solutions_sh510_30['D1_chisq'] > 0
solutions_sh510_30 = solutions_sh510_30[good]
ngood = len(solutions_sh510_30) // n_models
plot_distributions(solutions_sh510_30, correct='shock_5e10.dat',
                   label=f"Shock 5e10 -- 200 Mpc -- {ntrials} trials ({ntrials - ngood} invalid)")

In [None]:
# Fit results for the 'kilonova_0.04.dat' model (label quotes 130 Mpc).
# Presumably each trial contributes one row per candidate model, so row counts
# divided by the number of models give trial counts. The number of models is
# read from the table rather than hard-coded as 6; max(..., 1) avoids a
# ZeroDivisionError on an empty file.
solutions_k04 = Table.read('solutions_k04.csv')
n_models = max(len(set(solutions_k04['model'])), 1)
ntrials = len(solutions_k04) // n_models
# Rows with non-positive D1 chi^2 are treated as invalid fits and discarded.
good = solutions_k04['D1_chisq'] > 0
solutions_k04 = solutions_k04[good]
ngood = len(solutions_k04) // n_models
plot_distributions(solutions_k04, correct='kilonova_0.04.dat',
                   label=f"Kilonova 0.04 -- 130 Mpc -- {ntrials} trials ({ntrials - ngood} invalid)", nbins=7)