In [None]:
# Notebook setup: inline figures, automatic reloading of edited modules,
# and every import used by the cells below.
%matplotlib inline
%load_ext autoreload
%autoreload 2

import glob
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
# NOTE(review): duplicates the `plt` import on the previous line.
from matplotlib import pyplot as plt

import copy
from astropy.visualization import quantity_support
from astropy.table import Table, QTable
import matplotlib
import astropy.units as u
from astroduet.config import Telescope
from astroduet.background import background_pixel_rate
# NOTE(review): `font` is not referenced by any cell visible here; the font
# size is actually applied through `matplotlib.rcParams.update` below.
font = {'size'   : 22}
from astroduet.models import Simulations, fits_file, load_model_ABmag, load_model_fluence
# Set the default matplotlib font size to 22 for every figure.
matplotlib.rcParams.update({'font.size': 22})
from astroduet.lightcurve import get_lightcurve, lightcurve_through_image
import astroduet.image_utils
import seaborn as sns



In [None]:
def red_chi_sq(f, x, s, dof=None):
    """Return the reduced chi-squared of `f` against `x` with uncertainties `s`.

    The statistic is ``sum((f - x)**2 / s**2) / dof``.

    Parameters
    ----------
    f, x : array-like
        The two sets of values being compared (e.g. model and data).
    s : array-like
        Uncertainty associated with each point.
    dof : int, optional
        Number of degrees of freedom.  Defaults to ``len(f) - 1``.
    """
    n_free = len(f) - 1 if dof is None else dof
    weighted_sq_resid = (f - x)**2 / s**2
    return np.sum(weighted_sq_resid) / n_free


def fit_lightcurve(lightcurve, label='lightcurve fit', solutions=None,
                   debug=False, additional_info_for_table=None, min_snr=5):
    """Fit every model in ``Simulations().emgw_simulations`` to a light curve.

    For each model, the model fluence in the D1 band, in the D2 band, and the
    D2/D1 ratio are each scaled by a single free multiplicative constant and
    fitted to the corresponding observed series.  One row per model is
    appended to `solutions`.

    Parameters
    ----------
    lightcurve : table-like
        Must provide the columns ``time``, ``fluence_D1_fit``,
        ``fluence_D2_fit``, ``fluence_D1_fiterr``, ``fluence_D2_fiterr``,
        ``snr_D1`` and ``snr_D2``.  The fluence columns are accessed through
        ``.value``, so they are expected to be Quantity columns.
    label : str
        Figure title, only used when `debug` is True.
    solutions : table, optional
        Table to append the results to.  It is modified in place and also
        returned.  When None, a new ``QTable`` is created.
    debug : bool
        If True, plot the data and all the fitted models.
    additional_info_for_table : dict, optional
        Extra ``{column name: value}`` pairs stored in every row added by this
        call.  When the table is created here, the column dtypes are inferred
        from these values.
    min_snr : float
        Minimum signal-to-noise ratio required in *both* bands for a point to
        be used in the fit.

    Returns
    -------
    solutions : table
        The input table (or a new one) with one more row per model.  If fewer
        than two points survive the quality cuts, placeholder rows are added
        with scale factors 0, chi-squared -1 and ``ngood`` 0.
    """
    from astroduet.models import Simulations
    from scipy.optimize import curve_fit
    from scipy.interpolate import interp1d
    from astropy.visualization import quantity_support
    quantity_support()
    if additional_info_for_table is None:
        additional_info_for_table = {}

    if solutions is None:
        names = 'fit_model,D1,D2,ratio,D1_chisq,D2_chisq,ratio_chisq,ngood'.split(',') + \
            list(additional_info_for_table)
        dtypes = ['U30', float, float, float, float, float, float, int] + \
            [np.asarray(v).dtype for v in additional_info_for_table.values()]
        solutions = QTable(names=names, dtype=dtypes)

    lc_files = Simulations().emgw_simulations

    # Keep only points with a positive fitted fluence and a significant
    # detection in BOTH bands (the D2 cut used to test snr_D1 twice).
    good = ((lightcurve['fluence_D1_fit'] > 0) &
            (lightcurve['fluence_D2_fit'] > 0) &
            (lightcurve['snr_D1'] > min_snr) &
            (lightcurve['snr_D2'] > min_snr))
    lightcurve = lightcurve[good]

    if len(lightcurve) < 2:
        # With fewer than two points the fit has no degrees of freedom left
        # (dof = npoints - 1): record placeholder rows instead of fitting.
        print("Lightcurve is invalid")
        for lc_file in lc_files:
            empty_row_dict = {'fit_model': lc_file, 'D1': 0, 'D2': 0, 'ratio': 0,
                              'D1_chisq': -1, 'D2_chisq': -1, 'ratio_chisq': -1,
                              'ngood': 0}
            solutions.add_row({**empty_row_dict, **additional_info_for_table})
        return solutions

    times = lightcurve['time']
    fluence_D1 = lightcurve['fluence_D1_fit']
    fluence_D2 = lightcurve['fluence_D2_fit']
    fluence_D1_err = lightcurve['fluence_D1_fiterr']
    fluence_D2_err = lightcurve['fluence_D2_fiterr']

    ratio = fluence_D2 / fluence_D1
    # Relative errors are summed linearly, not in quadrature.
    ratio_err = ratio * (fluence_D1_err / fluence_D1 +
                         fluence_D2_err / fluence_D2)

    if debug:
        plt.figure(figsize=(15, 15))
        plt.suptitle(label)
        gs = plt.GridSpec(3, 1)
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1], sharex=ax1)
        axr = plt.subplot(gs[2], sharex=ax1)
        ax1.errorbar(times, fluence_D1, yerr=fluence_D1_err, fmt='o', markersize=5)
        ax2.errorbar(times, fluence_D2, yerr=fluence_D2_err, fmt='o', markersize=5)
        axr.errorbar(times, ratio, yerr=ratio_err, fmt='o', markersize=5)

    for lc_file in lc_files:
        model_lc_table_fl = QTable(load_model_fluence(lc_file))
        # The model time axis is converted to seconds; `times` is passed to
        # the interpolators as-is, so it is presumably in seconds too
        # -- TODO confirm.
        model_times = model_lc_table_fl['time'].to(u.s).value
        interpolated_lc_1 = interp1d(model_times,
                                     model_lc_table_fl['fluence_D1'].value,
                                     fill_value=0, bounds_error=False)
        interpolated_lc_2 = interp1d(model_times,
                                     model_lc_table_fl['fluence_D2'].value,
                                     fill_value=0, bounds_error=False)

        def interpolated_lc_ratio(time):
            # NOTE(review): outside the model time range both interpolators
            # return 0, so this evaluates 0 / 0 there.
            return interpolated_lc_2(time) / interpolated_lc_1(time)

        # Each fit has a single free parameter: a multiplicative scale.
        def constant_fit_fun_1(x, a):
            return a * interpolated_lc_1(x)

        def constant_fit_fun_2(x, a):
            return a * interpolated_lc_2(x)

        def constant_fit_fun_ratio(x, a):
            return a * interpolated_lc_ratio(x)

        par1, pcov1 = curve_fit(constant_fit_fun_1,
                                times, fluence_D1,
                                sigma=fluence_D1_err, p0=[1])
        par2, pcov2 = curve_fit(constant_fit_fun_2,
                                times, fluence_D2,
                                sigma=fluence_D2_err, p0=[1])
        parr, pcovr = curve_fit(constant_fit_fun_ratio,
                                times, ratio, sigma=ratio_err, p0=[1])

        # One fitted parameter per series -> dof = npoints - 1.
        d1_chisq = red_chi_sq(constant_fit_fun_1(times, *par1),
                              fluence_D1.value, fluence_D1_err.value,
                              dof=len(fluence_D1) - 1)
        d2_chisq = red_chi_sq(constant_fit_fun_2(times, *par2),
                              fluence_D2.value, fluence_D2_err.value,
                              dof=len(fluence_D2) - 1)
        ratio_chisq = red_chi_sq(constant_fit_fun_ratio(times, *parr),
                                 ratio.value, ratio_err.value,
                                 dof=len(ratio) - 1)

        # curve_fit returns length-1 arrays; store plain scalars in the
        # scalar float columns.
        new_row = {'fit_model': lc_file,
                   'D1': par1[0], 'D2': par2[0], 'ratio': parr[0],
                   'D1_chisq': d1_chisq, 'D2_chisq': d2_chisq,
                   'ratio_chisq': ratio_chisq,
                   'ngood': len(fluence_D1)}
        solutions.add_row({**new_row, **additional_info_for_table})

        if debug:
            fine_times = np.linspace(times[0], times[-1], 1000)
            ax1.plot(fine_times,
                     constant_fit_fun_1(fine_times, *par1), label=lc_file)
            ax2.plot(fine_times,
                     constant_fit_fun_2(fine_times, *par2), label=lc_file)
            axr.plot(fine_times,
                     constant_fit_fun_ratio(fine_times, *parr), label=lc_file)

    if debug:
        axr.set_ylabel('Flux ratio')
        ax1.legend()

    return solutions


In [None]:
# Display `Simulations().emgw_simulations`: the collection of models that
# `fit_lightcurve` iterates over and `simulate_and_fit` samples from.
Simulations().emgw_simulations

In [None]:
def simulate_and_fit(label=None, debug=False, exposure=300 * u.s,
                     start=1800 * u.s, stop=30000 * u.s,
                     final_resolution=1200 * u.s, distances=[100 * u.Mpc, 200 * u.Mpc],
                     ntrial=100, outfile=None, allowed_sims=None):
    """Monte Carlo loop: simulate random light curves and fit all models to each.

    On every trial a model is drawn at random from `allowed_sims` and a
    distance is drawn uniformly between ``distances[0]`` and ``distances[1]``.
    A light curve is generated with ``get_lightcurve``, passed through
    ``lightcurve_through_image`` and fitted with ``fit_lightcurve``.

    Parameters
    ----------
    label : optional
        Currently unused; kept for interface compatibility.
    debug : bool
        Forwarded to ``lightcurve_through_image`` only.  ``fit_lightcurve`` is
        always called with ``debug=False``.
    exposure : Quantity
        Forwarded to ``get_lightcurve`` and ``lightcurve_through_image``.
    start, stop : Quantity
        Edges of the single observing window.
    final_resolution : Quantity
        Forwarded to ``lightcurve_through_image``.
    distances : sequence of two Quantity
        Lower and upper bound of the uniform distance distribution.  The
        default list is only read, never modified.
    ntrial : int
        Number of simulated light curves.
    outfile : str, optional
        File the results are written to.  If it already exists, its content is
        read with ``Table.read`` and new rows are appended to it.  When None,
        results are written to ``'monte_carlo.csv'`` starting from an empty
        table.
    allowed_sims : sequence, optional
        Models to draw from.  Defaults to ``Simulations().emgw_simulations``.

    Returns
    -------
    solutions : table or None
        The accumulated fit results, or None if nothing could be fitted.
    """
    import tqdm
    from astroduet.utils import suppress_stdout

    nwrite = 20  # write to disk every `nwrite` trials
    if allowed_sims is None:
        allowed_sims = Simulations().emgw_simulations

    # A user-supplied `outfile` is honoured even if it does not exist yet;
    # the default name is used only when no file name was given.
    solutions = None
    if outfile is None:
        outfile = 'monte_carlo.csv'
    elif os.path.exists(outfile):
        solutions = Table.read(outfile)

    # Express both bounds in the unit of the first one before stripping units.
    dist_unit = distances[0].unit
    dist_min = distances[0].value
    dist_max = distances[1].to(dist_unit).value

    unsaved = False  # True when `solutions` holds rows not yet on disk

    for i in tqdm.tqdm(range(ntrial)):
        observing_windows = \
            np.array([[start.value, stop.to(start.unit).value]]) * start.unit
        distance = np.random.uniform(dist_min, dist_max) * dist_unit

        input_lc_file = np.random.choice(allowed_sims)
        galaxy_type="none        "
        additional_info_for_table = {'model': input_lc_file, 'distance': distance,
                                     'start': start, 'end': stop, 'galaxy': galaxy_type,
                                     'final_resolution': final_resolution}
        try:
            with suppress_stdout():
                lightcurve_init = \
                    get_lightcurve(input_lc_file, exposure=exposure,
                                   observing_windows=observing_windows,
                                   distance=distance)

                lightcurve_rebin = \
                    lightcurve_through_image(lightcurve_init, exposure=exposure,
                                             final_resolution=final_resolution,
                                             debug=debug, silent=True)

            solutions = fit_lightcurve(
                lightcurve_rebin, label=input_lc_file, solutions=solutions,
                debug=False, additional_info_for_table=additional_info_for_table)
            unsaved = True
            if (i + 1) % nwrite == 0:
                solutions.write(outfile, overwrite=True)
                unsaved = False
        except Exception as e:
            # Best effort: stop the loop but still return what was collected.
            print("An exception occurred. Intermediate results are saved in the solution Table")
            print(e)
            break

    # Flush the trials accumulated since the last periodic write, so that runs
    # whose length is not a multiple of `nwrite` (or that stopped early) are
    # not lost.  A failure here must not prevent returning the table.
    if unsaved and solutions is not None:
        try:
            solutions.write(outfile, overwrite=True)
        except Exception as e:
            print("Could not write the results to {}".format(outfile))
            print(e)

    return solutions

In [None]:
# Long-running cell: 10000 simulated light curves at distances drawn between
# 50 and 400 Mpc.  The result is bound to a name so that it stays available
# to later cells instead of living only in the cell output.
mc_solutions = simulate_and_fit(ntrial=10000, outfile='monte_carlo.csv',
                                distances=[50 * u.Mpc, 400 * u.Mpc])
mc_solutions
