In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import glob
import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
from astropy import log

import copy
from astropy.visualization import quantity_support
from astropy.table import Table, QTable
import matplotlib
import astropy.units as u
from astroduet.config import Telescope
from astroduet.background import background_pixel_rate
font = {'size'   : 22}
from astroduet.models import Simulations, fits_file, load_model_ABmag, load_model_fluence
matplotlib.rcParams.update({'font.size': 22})
from astroduet.lightcurve import get_lightcurve, lightcurve_through_image
import astroduet.image_utils
import seaborn as sns



In [None]:
# s = Simulations()
# s.parse_sne_bsg()
# s.parse_sne_rsg()


In [None]:
def get_random_galaxy(distance, duet, tolerance=10 * u.Mpc):
    """Draw random host-galaxy parameters from the BAI catalogue.

    A galaxy is chosen at random among BAI entries whose ``DIST`` lies in
    ``[distance - tolerance, distance + tolerance)``. If none qualify, a
    warning is logged and the choice is made among galaxies with
    ``DIST > 150 Mpc`` instead. The galaxy's ``RAD`` is rescaled to the
    requested ``distance`` and expressed in pixels of an image oversampled
    6x with respect to ``duet.pixel``. Orientation, ellipticity and the
    offset of the galaxy centre are drawn at random.

    Parameters
    ----------
    distance : `~astropy.units.Quantity`
        Distance at which the galaxy is placed.
    duet : `~astroduet.config.Telescope`
        Telescope configuration; its ``pixel`` scale sets the output pixel
        units and it is forwarded to ``galex_to_duet``.
    tolerance : `~astropy.units.Quantity`, optional
        Half-width of the distance window used to select candidates.

    Returns
    -------
    gal_params1, gal_params2 : dict
        Parameters for the two DUET bands, identical except for
        ``magnitude`` (the ``BAI1`` / ``BAI2`` values obtained from the GALEX
        ``SURFFUV`` / ``SURFNUV`` columns through ``galex_to_duet``).
        Keys: ``magnitude``, ``r_eff``, ``n``, ``theta``, ``ellip``,
        ``x_0``, ``y_0``.

    Raises
    ------
    ValueError
        If not even the fallback (> 150 Mpc) sample contains any galaxy.
    """
    from astroduet.models import load_bai
    from astropy.table import QTable
    from astroduet.utils import galex_to_duet

    bai = QTable(load_bai())
    bai['BAI1'], bai['BAI2'] = galex_to_duet(
        [bai['SURFFUV'].value, bai['SURFNUV'].value], duet=duet)

    # Pixel scale of the oversampled image the galaxy is rendered on.
    oversample = 6
    pixel_size_init = duet.pixel / oversample

    galaxies_within_distance = ((bai['DIST'] >= distance - tolerance)
                                & (bai['DIST'] < distance + tolerance))
    if not np.any(galaxies_within_distance):
        # Logger.warn is a deprecated alias; use warning().
        log.warning(f"No galaxies in BAI catalogue between "
                    f"{(distance - tolerance).to(u.Mpc).value:.2f} "
                    f"and {(distance + tolerance).to(u.Mpc).value:.2f} Mpc."
                    " Choosing randomly from sample > 150 Mpc")
        galaxies_within_distance = bai['DIST'] > 150 * u.Mpc
        if not np.any(galaxies_within_distance):
            raise ValueError("No galaxies beyond 150 Mpc in the BAI catalogue.")

    # NOTE(review): the record returned by np.random.choice appears to hold
    # bare values (units are re-attached explicitly below) -- confirm if the
    # catalogue schema changes.
    galaxy = np.random.choice(bai[galaxies_within_distance])

    # Apparent size scales inversely with distance.
    rad = (galaxy['RAD'] * u.arcsec * galaxy['DIST'] * u.Mpc / distance).to(u.arcsec)
    rad_pix = rad / pixel_size_init

    # Sersic index: n = 1 for MORPH >= 0, n = 4 otherwise (presumably
    # exponential disk for late types, de Vaucouleurs for early types).
    if galaxy['MORPH'] >= 0:
        n = 1
    else:
        n = 4
    theta = np.random.uniform(0, np.pi)
    ellip = np.random.uniform(0.1, 1)
    # Offset of the galaxy centre, between 0.1 and 3 galaxy radii.
    pos = np.random.uniform(0.1, 3)
    d = pos * rad / pixel_size_init
    x_0 = d * np.cos(theta)
    y_0 = d * np.sin(theta)

    gal_params = dict(r_eff=rad_pix.value, n=n, theta=theta, ellip=ellip,
                      x_0=x_0.value, y_0=y_0.value)
    gal_params1 = {'magnitude': galaxy['BAI1'], **gal_params}
    gal_params2 = {'magnitude': galaxy['BAI2'], **gal_params}
    return gal_params1, gal_params2


In [None]:
def red_chi_sq(f, x, s, dof=None):
    """Reduced chi-square of model ``f`` with respect to data ``x``.

    Parameters
    ----------
    f, x, s : array-like (numpy arrays)
        Model values, measured values and 1-sigma uncertainties.
    dof : int, optional
        Degrees of freedom used as the normalisation; defaults to
        ``len(f) - 1``.

    Returns
    -------
    float
        ``sum((f - x)**2 / s**2) / dof``.
    """
    if dof is not None:
        n_dof = dof
    else:
        n_dof = len(f) - 1
    chi_sq = np.sum((f - x) ** 2 / s ** 2)
    return chi_sq / n_dof


def _new_solutions_table(additional_info_for_table):
    """Create an empty QTable with the fit-result columns plus the extra ones."""
    fit_names = ('fit_model,D1,D2,ratio,D1_chisq,D2_chisq,ratio_chisq,'
                 'D1_chisq_nofit,D2_chisq_nofit,ratio_chisq_nofit,ngood').split(',')
    fit_dtypes = ['U30', float, float, float, float, float, float,
                  float, float, float, int]
    add_names = list(additional_info_for_table)
    # Columns whose name contains 'model' hold model names (strings); the
    # others take the dtype of the value supplied for the first row.
    add_dtypes = [np.asarray(v).dtype if 'model' not in k else 'U30'
                  for k, v in additional_info_for_table.items()]
    return QTable(names=fit_names + add_names, dtype=fit_dtypes + add_dtypes)


def fit_lightcurve(lightcurve, label='lightcurve fit', solutions=None,
                   debug=False, additional_info_for_table=None,
                   fit_model_files=None, correct_model="", distance=None):
    """Fit a set of model light curves to a simulated DUET light curve.

    Points are kept if they have a positive fitted fluence and SNR > 5 in
    at least one band. For each model, the model fluences (computed at
    ``distance``) are linearly interpolated onto the observed times, and a
    single multiplicative scale factor is fitted separately to the D1
    fluence, the D2 fluence and the D2/D1 ratio. Reduced chi-squares are
    recorded both for the fitted scale and for the unscaled model.

    Parameters
    ----------
    lightcurve : table
        Must provide ``time``, ``fluence_D1_fit``, ``fluence_D2_fit``,
        ``fluence_D1_fiterr``, ``fluence_D2_fiterr``, ``snr_D1`` and
        ``snr_D2`` columns; the fluence columns must be Quantities.
    label : str
        Figure title used when ``debug`` is True.
    solutions : table, optional
        Table the result rows are appended to; a new QTable is created if
        None.
    debug : bool
        If True, plot the data together with every fitted model.
    additional_info_for_table : dict, optional
        Extra column values added to every row written.
    fit_model_files : list, optional
        Model names passed to ``load_model_fluence``; defaults to
        ``Simulations().emgw_simulations``.
    correct_model : str
        Name of the model the data were simulated from; only used to
        highlight it in debug plots.
    distance : `~astropy.units.Quantity`, optional
        Distance at which model fluences are computed (default 100 Mpc).

    Returns
    -------
    solutions : table
        The input (or new) table with one row appended per model. If fewer
        than 5 points pass the selection, placeholder rows (scales 0,
        chi-squares -1, ``ngood`` 0) are appended instead.
    """
    from astroduet.models import Simulations
    from scipy.optimize import curve_fit
    from scipy.interpolate import interp1d
    from astropy.visualization import quantity_support
    quantity_support()
    if additional_info_for_table is None:
        additional_info_for_table = {}
    if distance is None:
        distance = 100 * u.Mpc
    if solutions is None:
        solutions = _new_solutions_table(additional_info_for_table)

    if fit_model_files is not None:
        lc_files = fit_model_files
    else:
        lc_files = Simulations().emgw_simulations

    # Keep points detected (positive fitted fluence, SNR > 5) in at least
    # one band.
    good = (((lightcurve['fluence_D1_fit'] > 0) & (lightcurve['snr_D1'] > 5))
            | ((lightcurve['fluence_D2_fit'] > 0) & (lightcurve['snr_D2'] > 5)))
    detected = lightcurve[good]

    if len(detected) < 5:
        print("Lightcurve is invalid")
        for lc_file in lc_files:
            empty_row_dict = {'fit_model': lc_file, 'D1': 0, 'D2': 0, 'ratio': 0,
                              'D1_chisq': -1, 'D2_chisq': -1, 'ratio_chisq': -1,
                              'D1_chisq_nofit': -1, 'D2_chisq_nofit': -1,
                              'ratio_chisq_nofit': -1, 'ngood': 0}
            solutions.add_row({**empty_row_dict, **additional_info_for_table})
        return solutions

    # NOTE(review): `times` is handed to the interpolators as-is while the
    # model grid is converted to seconds, so it is assumed to be in s.
    times = detected['time']
    fluence_D1 = detected['fluence_D1_fit']
    fluence_D2 = detected['fluence_D2_fit']
    fluence_D1_err = detected['fluence_D1_fiterr']
    fluence_D2_err = detected['fluence_D2_fiterr']

    # Relative errors are summed linearly (a conservative bound compared with
    # adding them in quadrature).
    # NOTE(review): points kept only for their D2 detection may have
    # fluence_D1 <= 0, giving negative or non-finite ratios.
    ratio = fluence_D2 / fluence_D1
    ratio_err = ratio * (fluence_D1_err / fluence_D1 +
                         fluence_D2_err / fluence_D2)

    if debug:
        plt.figure(figsize=(15, 15))
        plt.suptitle(label)
        gs = plt.GridSpec(3, 1)
        ax1 = plt.subplot(gs[0])
        ax2 = plt.subplot(gs[1], sharex=ax1)
        axr = plt.subplot(gs[2], sharex=ax1)
        ax1.errorbar(times, fluence_D1, yerr=fluence_D1_err, fmt='o', markersize=5, label='data')
        ax2.errorbar(times, fluence_D2, yerr=fluence_D2_err, fmt='o', markersize=5, label='data')
        axr.errorbar(times, ratio, yerr=ratio_err, fmt='o', markersize=5, label='data')

    n_points = len(fluence_D1)
    for lc_file in lc_files:
        model_lc_table_fl = QTable(load_model_fluence(lc_file, dist=distance))
        model_times = model_lc_table_fl['time'].to(u.s).value
        # Outside the model's time range the model fluence is taken as 0.
        interpolated_lc_1 = interp1d(model_times,
                                     model_lc_table_fl['fluence_D1'].value,
                                     fill_value=0, bounds_error=False)
        interpolated_lc_2 = interp1d(model_times,
                                     model_lc_table_fl['fluence_D2'].value,
                                     fill_value=0, bounds_error=False)

        def interpolated_lc_ratio(time):
            return interpolated_lc_2(time) / interpolated_lc_1(time)

        # One-parameter models: the model curve times a free scale factor.
        def constant_fit_fun_1(x, a):
            return a * interpolated_lc_1(x)

        def constant_fit_fun_2(x, a):
            return a * interpolated_lc_2(x)

        def constant_fit_fun_ratio(x, a):
            return a * interpolated_lc_ratio(x)

        # Reduced chi-square of the unscaled model (scale = 1).
        # NOTE(review): nothing is fitted here, so N rather than N - 1 degrees
        # of freedom would be the textbook choice; N - 1 is kept so results
        # stay comparable with existing output files.
        d1_chisq_nofit = red_chi_sq(constant_fit_fun_1(times, 1),
                                    fluence_D1.value, fluence_D1_err.value, dof=n_points - 1)
        d2_chisq_nofit = red_chi_sq(constant_fit_fun_2(times, 1),
                                    fluence_D2.value, fluence_D2_err.value, dof=n_points - 1)
        ratio_chisq_nofit = red_chi_sq(constant_fit_fun_ratio(times, 1),
                                       ratio.value, ratio_err.value, dof=len(ratio) - 1)

        par1, pcov1 = curve_fit(constant_fit_fun_1, times, fluence_D1,
                                sigma=fluence_D1_err, p0=[1])
        par2, pcov2 = curve_fit(constant_fit_fun_2, times, fluence_D2,
                                sigma=fluence_D2_err, p0=[1])
        parr, pcovr = curve_fit(constant_fit_fun_ratio, times, ratio,
                                sigma=ratio_err, p0=[1])

        d1_chisq = red_chi_sq(constant_fit_fun_1(times, *par1),
                              fluence_D1.value, fluence_D1_err.value, dof=n_points - 1)
        d2_chisq = red_chi_sq(constant_fit_fun_2(times, *par2),
                              fluence_D2.value, fluence_D2_err.value, dof=n_points - 1)
        ratio_chisq = red_chi_sq(constant_fit_fun_ratio(times, *parr),
                                 ratio.value, ratio_err.value, dof=len(ratio) - 1)

        # curve_fit returns length-1 parameter arrays; store the scalars.
        new_row = {'fit_model': lc_file, 'D1': par1[0], 'D2': par2[0], 'ratio': parr[0],
                   'D1_chisq': d1_chisq, 'D2_chisq': d2_chisq, 'ratio_chisq': ratio_chisq,
                   'D1_chisq_nofit': d1_chisq_nofit, 'D2_chisq_nofit': d2_chisq_nofit,
                   'ratio_chisq_nofit': ratio_chisq_nofit,
                   'ngood': n_points}
        solutions.add_row({**new_row, **additional_info_for_table})

        if debug:
            fine_times = np.linspace(times[0], times[-1], 1000)
            # Highlight the model the light curve was simulated from.
            is_input_model = (lc_file.replace('.dat', '')
                              == correct_model.replace('.dat', ''))
            alpha, lw = (1, 1) if is_input_model else (0.3, 0.5)

            ax1.plot(fine_times, constant_fit_fun_1(fine_times, *par1),
                     label=lc_file, alpha=alpha, lw=lw, markersize=0)
            ax2.plot(fine_times, constant_fit_fun_2(fine_times, *par2),
                     label=lc_file, alpha=alpha, lw=lw, markersize=0)
            axr.plot(fine_times, constant_fit_fun_ratio(fine_times, *parr),
                     label=lc_file, alpha=alpha, lw=lw, markersize=0)
            # The default correct_model="" is a substring of every name; guard
            # against it so the unscaled model is not drawn for every file.
            if correct_model and correct_model in lc_file:
                ax1.plot(fine_times, constant_fit_fun_1(fine_times, 1),
                         label=correct_model, color='b', lw=1, markersize=0)
                ax2.plot(fine_times, constant_fit_fun_2(fine_times, 1),
                         label=correct_model, color='b', lw=1, markersize=0)
                axr.plot(fine_times, constant_fit_fun_ratio(fine_times, 1),
                         label=correct_model, color='b', lw=1, markersize=0)

    if debug:
        axr.set_ylabel('Flux ratio')
        ax1.legend()
        print("Debug done")
        plt.show()

    return solutions


In [None]:
def simulate_and_fit(label=None, debug=False, exposure=300 * u.s,
                     observing_windows=[[1800, 30000]] * u.s,
                     final_resolution=1200 * u.s, distances=[100 * u.Mpc, 200 * u.Mpc],
                     ntrial=100, outfile=None, allowed_sims=None, allowed_fit_sims=None,
                     galaxy_type="none", nwrite=20):
    """Monte Carlo: simulate random transients through DUET and fit models.

    Each trial does the following:

    * draws a distance uniformly between ``distances[0]`` and ``distances[1]``;
    * draws an input model from ``allowed_sims``;
    * for ``galaxy_type="ran1"``, draws a random BAI host galaxy;
    * simulates the light curve through the image pipeline;
    * fits every model in ``allowed_fit_sims`` with ``fit_lightcurve``.

    Parameters
    ----------
    label : unused
        Kept for backward compatibility.
    debug : bool
        Forwarded to the image simulation and to the fit (plots).
    exposure, final_resolution : `~astropy.units.Quantity`
        Exposure per frame and final light-curve bin size.
    observing_windows : `~astropy.units.Quantity`
        Observing windows (presumably (N, 2) start/stop pairs); their global
        min/max are stored as the ``start``/``end`` columns.
    distances : sequence of two `~astropy.units.Quantity`
        Lower and upper bound of the distance draw.
    ntrial : int
        Number of trials.
    outfile : str, optional
        CSV file the results are appended to (it is read first if it exists)
        and rewritten every ``nwrite`` trials and at the end. Defaults to
        ``'monte_carlo.csv'``.
    allowed_sims : list, optional
        Input models; defaults to ``Simulations().emgw_simulations``.
    allowed_fit_sims : list, optional
        Models fitted to each light curve; defaults to ``allowed_sims``.
    galaxy_type : {"none", "ran1"}
        "none": no host galaxy; "ran1": a random BAI galaxy per trial.
    nwrite : int
        Checkpoint period, in trials.

    Returns
    -------
    solutions : table or None
        Accumulated fit results. On an exception, the traceback is printed,
        a ``<outfile root>_recovered`` copy is written and the loop stops.

    Raises
    ------
    ValueError
        If ``galaxy_type`` is not one of the supported values.
    """
    import traceback
    import tqdm
    from astroduet.utils import suppress_stdout

    if galaxy_type not in ("none", "ran1"):
        # Any other value previously injected an all-zero (magnitude 0)
        # galaxy into the simulation.
        raise ValueError(f"Unsupported galaxy_type {galaxy_type!r}; "
                         "use 'none' or 'ran1'.")

    start = np.min(observing_windows.value) * observing_windows.unit
    stop = np.max(observing_windows.value) * observing_windows.unit

    if allowed_sims is None:
        allowed_sims = Simulations().emgw_simulations
    if allowed_fit_sims is None:
        allowed_fit_sims = allowed_sims

    if outfile is None:
        outfile = 'monte_carlo.csv'
    # Resume from (and append to) results already on disk, whether or not the
    # file name was given explicitly, instead of silently overwriting them.
    solutions = Table.read(outfile) if os.path.exists(outfile) else None
    root, ext = os.path.splitext(outfile)
    recovered_file = f"{root}_recovered{ext or '.csv'}"

    # Template for the galaxy columns: keys from a real draw, values zeroed,
    # so rows without a host galaxy share the same table schema.
    galaxy_parameters = {key: 0. for key in
                         get_random_galaxy(50 * u.Mpc, Telescope())[0]}

    # Express both distance bounds in the same unit before drawing.
    dist_unit = distances[0].unit
    dist_low = distances[0].value
    dist_high = distances[1].to(dist_unit).value

    for i in tqdm.tqdm(range(ntrial)):
        distance = np.random.uniform(dist_low, dist_high) * dist_unit

        input_lc_file = np.random.choice(allowed_sims)
        if galaxy_type == "ran1":
            galaxy_parameters = get_random_galaxy(distance, Telescope())[0]

        additional_info_for_table = {'model': input_lc_file, 'distance': distance,
                                     'start': start, 'end': stop, 'galaxy': galaxy_type,
                                     'final_resolution': final_resolution, **galaxy_parameters}
        gal_type = gal_params = None
        if galaxy_type != "none":
            gal_type = "custom"
            gal_params = galaxy_parameters

        try:
            with suppress_stdout():
                lightcurve_init = get_lightcurve(
                    input_lc_file, exposure=exposure,
                    observing_windows=observing_windows, distance=distance)
                lightcurve_rebin = lightcurve_through_image(
                    lightcurve_init, exposure=exposure,
                    final_resolution=final_resolution,
                    debug=debug, silent=True,
                    gal_type=gal_type, gal_params=gal_params)

            solutions = fit_lightcurve(
                lightcurve_rebin, label=input_lc_file, solutions=solutions,
                debug=debug, additional_info_for_table=additional_info_for_table,
                fit_model_files=allowed_fit_sims, correct_model=input_lc_file,
                distance=distance)
            if (i + 1) % nwrite == 0:
                solutions.write(outfile, overwrite=True)
        except Exception:
            print("An exception occurred. Intermediate results are saved in the solution Table")
            # print_exc includes the exception type and message, which
            # print_tb alone omits.
            traceback.print_exc()
            if solutions is not None and len(solutions) > 0:
                solutions.write(recovered_file, overwrite=True)
            break

    if solutions is not None and len(solutions) > 0:
        solutions.write(outfile, overwrite=True)

    return solutions

In [None]:
# # simulate_and_fit(ntrial=10000, outfile='monte_carlo.csv', distances=[50 * u.Mpc, 400 * u.Mpc])
# simulate_and_fit(ntrial=1, outfile='monte_carlo_nofit.csv', distances=[10 * u.Mpc, 300 * u.Mpc], 
#                          allowed_sims=['bsg20', 'bsg80'],
#                          allowed_fit_sims=s.sne_bsg_simulations + s.sne_rsg_simulations, 
#                          observing_windows = observing_windows[:1], final_resolution=1200 * u.s,
#                          galaxy_type=galaxy_type, nwrite=1)

In [None]:
# Monte Carlo driver: repeatedly simulate each group of input models (with and
# without a random host galaxy) and append the fit results to MC_OUTFILE.
#
# `s` used to be defined only in a commented-out cell, so this cell failed on a
# fresh kernel. NOTE(review): this assumes Simulations() exposes the
# sne_*_simulations lists without running the parse_sne_* methods -- run those
# once beforehand if the model files are missing.
s = Simulations()
all_fit_sims = s.sne_bsg_simulations + s.sne_rsg_simulations + s.sne_ysg_simulations

MC_OUTFILE = 'monte_carlo_nofit.csv'
MC_OBSERVING_WINDOWS = np.array([[0, 430000]]) * u.s
MC_FINAL_RESOLUTION = 1200 * u.s
MC_TRIALS_PER_CALL = 4
# Number of passes over all configurations; None keeps running until the
# kernel is interrupted.
MC_N_ROUNDS = None

# (input models, [min distance, max distance]) for each simulation group.
mc_configurations = [
    (['bsg20', 'bsg80'], [10 * u.Mpc, 300 * u.Mpc]),
    (['ysg150', 'ysg400'], [10 * u.Mpc, 800 * u.Mpc]),
    (['rsg400', 'rsg600'], [10 * u.Mpc, 800 * u.Mpc]),
]

completed_rounds = 0
while MC_N_ROUNDS is None or completed_rounds < MC_N_ROUNDS:
    for galaxy_type in ["ran1", "none"]:
        for input_sims, distance_range in mc_configurations:
            simulate_and_fit(ntrial=MC_TRIALS_PER_CALL, outfile=MC_OUTFILE,
                             distances=distance_range,
                             allowed_sims=input_sims,
                             allowed_fit_sims=all_fit_sims,
                             observing_windows=MC_OBSERVING_WINDOWS,
                             final_resolution=MC_FINAL_RESOLUTION,
                             galaxy_type=galaxy_type, nwrite=1)
    completed_rounds += 1
