In [None]:
%matplotlib notebook

import numpy as np
import matplotlib.pyplot as plt
import os
from astropy.io import ascii
from astropy.table import Table,Column,vstack
import iminuit
from iminuit import Minuit,describe
from iminuit.util import make_func_code
from matplotlib.pyplot import cm 
from collections import OrderedDict

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
# Configuration used all along this notebook.
# ImSimpy_DIR must point to the ImSimpy installation: fail early with a clear
# message instead of "unsupported operand type(s) for +: 'NoneType' and 'str'".
imsimpy_dir = os.getenv('ImSimpy_DIR')
if imsimpy_dir is None:
    raise EnvironmentError('The ImSimpy_DIR environment variable is not set')
configFile = imsimpy_dir + '/ImSimpy/configFiles/GRB091020.hjson'

# Telescopes to use
name_telescope = ['colibri', 'VT']

# Paths to the photometric data and to their spline interpolations
path2data = 'data/GRB091020/'
path2data_spline = 'data/GRB091020_spline/'

# Name of output_dir. Used for creating folders in catalog/, images/, etc.
output_dir = 'GRB091020'

# GRB to study
GRB = '091020'

# Bands to consider (each should have at least 2 detections)
bands = ['Blc', 'Vlc', 'Rlc']

# One color per band. Reversed so that, when looping over bands[::-1]
# (R, V, B), R is drawn in red and B in violet.
cmap = cm.rainbow(np.linspace(0, 1, len(bands)))[::-1]
# Default band for the single-band loading cell
band = bands[0]


In [None]:
# Load the observed photometry and the spline interpolation of the default band
data_file = path2data + band + GRB + '.txt'
spline_file = path2data_spline + band + GRB + '_spline.txt'
data = ascii.read(data_file, delimiter=';', data_start=1)
data_spline = ascii.read(spline_file, delimiter='\t', data_start=1)

In [None]:
# Standard zero points (Jy) used to convert magnitudes to flux densities.
# Keys are band names without surrounding whitespace.
_ZP_BY_BAND = {name: zp
               for names, zp in ((('U',), 1790),
                                 (('B',), 4063),
                                 (('V', 'v'), 3631),
                                 (('R', 'Rc', 'CR', 'un'), 3064),
                                 (('r', "r'", 'i', "i'", 'g', "g'"), 3631),
                                 (('I', 'CI', 'Ic'), 2416),
                                 (('z',), 3631),
                                 (('J',), 1589),
                                 (('H',), 1021),
                                 (('K', 'Ks'), 640))
               for name in names}


def ZP(band):
    """Return the zero-point flux density (Jy) of a photometric band.

    The band name is matched after stripping surrounding whitespace, so
    'R', ' R' and " r' " are all recognised (data files pad filter names
    with a leading space).

    Parameters
    ----------
    band : str
        Filter name as found in the data files.

    Returns
    -------
    int or None
        Zero point in Jy, or None (after printing a warning) if the band is unknown.
    """
    zp = _ZP_BY_BAND.get(str(band).strip())
    if zp is None:
        print ('band %s not found' % band)
    return zp

# Effective wavelengths (Angstrom) of the main filter bands.
# Keys are band names without surrounding whitespace.
_EFF_WVL_BY_BAND = {name: wvl
                    for names, wvl in ((('U',), 3650),
                                       (('B',), 4400),
                                       (('V', 'v'), 5468),
                                       (('R', 'Rc', 'CR', 'r', "r'"), 6580),
                                       (('I', 'CI', 'i', 'Ic'), 8060),
                                       (('z', "z'"), 8917),
                                       (('J',), 12200),
                                       (('H',), 16300),
                                       (('K', 'Ks'), 21900))
                    for name in names}


def eff_wvl(band):
    """Return the effective wavelength (Angstrom) of a photometric band.

    The band name is matched after stripping surrounding whitespace, so the
    same spellings are accepted as by ``ZP`` (e.g. ' Rc' is now recognised).

    Parameters
    ----------
    band : str
        Filter name as found in the data files.

    Returns
    -------
    int or None
        Effective wavelength in Angstrom, or None (after printing a warning)
        if the band is unknown.
    """
    wvl = _EFF_WVL_BY_BAND.get(str(band).strip())
    if wvl is None:
        print ('band %s not found' % band)
    return wvl

In [None]:
# Plot the light curves: observed points (with error bars) and spline interpolations
seconds_per_unit = {'m': 60, 'h': 3600, 'd': 86400}  # minutes, hours, days -> seconds

plt.figure()
for i, band in enumerate(bands[::-1]):
    data = ascii.read(path2data+band+GRB+'.txt', delimiter=';', data_start=1)
    data_spline = ascii.read(path2data_spline+band+GRB+'_spline.txt', delimiter='\t', data_start=1)

    # Convert observation start/end times to seconds
    for unit, factor in seconds_per_unit.items():
        mask = data['unit_time'] == unit
        data['Tstart'][mask] *= factor
        data['Tend'][mask] *= factor

    # Magnitude -> flux density (ZP in Jy, *1e3 -> mJy) and propagated error
    flux = np.array([ZP(filt)*1e3*10**(-0.4*mag) for filt, mag in zip(data['filter'], data['mag'])])
    flux_err = flux * 0.4 * np.log(10) * data['err_mag']

    # errorbar draws xerr on both sides of the point: use half the exposure duration
    time_mid = (data['Tstart'] + data['Tend']) / 2
    time_halfwidth = (data['Tend'] - data['Tstart']) / 2

    plt.plot(data_spline['mid_time (s)'], data_spline['flux [erg.cm-2.s-1 (X) or mJy (opt/IR, radio)]'],
             label=band[0], color=cmap[i])
    plt.errorbar(time_mid, flux, xerr=time_halfwidth, yerr=flux_err, fmt='o', color=cmap[i], markersize=2)

# Axis formatting does not depend on the band: apply it once
plt.xscale('log')
plt.yscale('log')
plt.xlabel('Time since trigger [s]', size=17)
plt.ylabel('Flux [mJy]', size=17)
plt.grid(True, alpha=0.4, ls='--')
plt.title('GRB %s' % GRB, size=18)
plt.tick_params(labelsize=15)
plt.legend(fontsize=13)
plt.tight_layout()
plt.savefig('LC_%s_LC.png' % GRB)

In [None]:
# Need to extrapolate both in time and wavelength

In [None]:
# Create a common table which includes the light curves at all wavelengths
folder = path2data
suffix = '.txt'
seconds_per_unit = {'m': 60, 'h': 3600, 'd': 86400}  # minutes, hours, days -> seconds

# Read each band, then stack all of them at once (same pattern as the spline tables below)
band_tables = []
for band in bands:
    fname = folder + band + GRB + suffix
    data_lc = np.genfromtxt(fname, delimiter=';', skip_header = 1,
                            dtype=['U7','U5',float,float,float,'U3',float,float,
                                    'U3',float,float,float],
                            names=['GRB_name','trigger_sat','redshift',
                                    'Tstart','Tend','time_unit','mag','mag_err',
                                    'filter','Flux_mJy','Flux_mJy_err','telescope_ID'])
    band_tables.append(Table(data_lc))
data_table = vstack(band_tables, join_type='outer')

# Convert all times to seconds
for unit, factor in seconds_per_unit.items():
    mask = data_table['time_unit'] == unit
    data_table['Tstart'][mask] *= factor
    data_table['Tend'][mask] *= factor
data_table['time_unit'] = 's'

# Sort by ascending start time
data_table.sort(['Tstart'])

# Effective wavelength (Angstrom) and flux density of every point.
# ZP() returns Jy, so the factor 1e3 gives mJy.
eff_wvl_list = [eff_wvl(row['filter']) for row in data_table]
flux = np.array([ZP(row['filter'])*1e3*10**(-0.4*row['mag']) for row in data_table])
# Propagate the magnitude error to the flux: dF = 0.4 ln(10) F dm
flux_err = np.abs(flux * 0.4 * np.log(10) * data_table['mag_err'])
# Each magnitude is associated with the middle of its observation
time = (data_table['Tstart'] + data_table['Tend']) / 2

col_time = Column(name='Time', data=time, unit='s')
col_flux = Column(name='Flux', data=flux, unit='mJy')
col_flux_err = Column(name='Flux_err', data=flux_err, unit='mJy')
col_eff_wvl = Column(name='eff_wvl', data=eff_wvl_list, unit='Angstroms')
data_table.add_columns([col_eff_wvl, col_time, col_flux, col_flux_err])

# Keep only non-X-ray points with a finite time
mask = (data_table['filter'] != ' X') & (np.isfinite(data_table['Time']))
ascii.write(data_table['GRB_name','Time','filter','eff_wvl','Flux','Flux_err'][mask],
            'data/multi_LC_%s.txt' % GRB,overwrite=True)

# Read the spline interpolation of every band
folder_spline = path2data_spline
suffix_spline = '_spline.txt'
spline_tables = []
for band in bands:
    fname = folder_spline + band + GRB + suffix_spline
    data_spline = np.genfromtxt(fname,delimiter='\t', skip_header = 1,
                                dtype=['U7',float,float,float,float,
                                        float,float,float,'U3'],
                                names=['GRB_name','redshift','Time','Tstart','Tend',
                                        'Flux','Flux_err','dataType','filter'])
    spline_tables.append(Table(data_spline))
dataSpline_table = vstack(spline_tables)

# Sort by ascending start time
dataSpline_table.sort(['Tstart'])

# Add the effective wavelength (Angstrom) of every point
col_eff_wvl = Column(name='eff_wvl', data=[eff_wvl(row['filter']) for row in dataSpline_table],
                     unit='Angstroms')
dataSpline_table.add_columns([col_eff_wvl])

mask = (dataSpline_table['filter'] != ' X') & (np.isfinite(dataSpline_table['Time']))
ascii.write(dataSpline_table['GRB_name','Time','filter','eff_wvl','Flux','Flux_err'][mask],
            'data/multi_LC_%s_spline.txt' % GRB,overwrite=True)

In [None]:
# Display the merged photometry table (all bands, times in seconds, with the
# eff_wvl / Time / Flux / Flux_err columns added in the previous cell)
data_table

In [None]:
# Load the merged multi-band tables written by the previous cells
spline_path = 'data/multi_LC_%s_spline.txt' % GRB
obs_path = 'data/multi_LC_%s.txt' % GRB
observations_spline = ascii.read(spline_path)
observations = ascii.read(obs_path)

In [None]:
# Display the observed multi-band light curves read back from file
observations

First, fit each light curve separately (one fit per band).

Then, select a time at which to extract and fit the SED.

In [None]:
# Recipes (model functions) used to fit light curves and SEDs
def SPL_lc(t,F0,t0,norm,alpha):
    """Single power-law light curve: norm * F0 * (t / t0)**(-alpha)."""
    decay = (t/t0)**(-alpha)
    return norm*F0*decay

def BPL_lc(t,F0,norm,alpha1,alpha2,t1,s):
    """Smoothly broken power-law light curve (break time t1, smoothness s)."""
    ratio = t/t1
    smoothed = ratio**(-s*alpha1) + ratio**(-s*alpha2)
    return norm*F0*smoothed**(-1./s)


def SPL_sed(wvl,F0,wvl0,norm,beta):
    """Single power-law SED: norm * F0 * (wvl / wvl0)**beta."""
    shape = (wvl/wvl0)**beta
    return norm*F0*shape

def BPL_sed(wvl,F0,norm,beta1,beta2,wvl1,s):
    """Smoothly broken power-law SED (break wavelength wvl1, smoothness s)."""
    ratio = wvl/wvl1
    smoothed = ratio**(s*beta1) + ratio**(s*beta2)
    return norm*F0*smoothed**(1./s)

def template1(wvl,t,F0,wvl0,t0,norm,alpha,beta,z,Av,ext_law,
              *,Host_dust=True,Host_gas=False,MW_dust=False,MW_gas=False,DLA=False,igm_att='meiksin'):
    """Single power law in time and wavelength, attenuated by ``sed_extinction``.

    The extinction switches (Host_dust, Host_gas, MW_dust, MW_gas, DLA,
    igm_att) were previously read from module globals that are never defined,
    so every call raised NameError. They are now keyword-only arguments
    (positional calls are unchanged); the defaults match the configuration
    hard-coded in ``Flux_template``.
    """
    # Local import: sed_extinction is otherwise only imported in a later cell
    from grb_photoz.extinction_correction import sed_extinction
    extinction = sed_extinction(wvl,z,Av,ext_law=ext_law,Host_dust=Host_dust,Host_gas=Host_gas,
                                MW_dust=MW_dust,MW_gas=MW_gas,DLA=DLA,igm_att=igm_att)
    return norm *F0 * (t/t0)**(-alpha) * (wvl/wvl0)**beta * extinction

def template2(wvl,t,F0,wvl0,norm,alpha1,alpha2,t1,s,beta,z,Av,ext_law,
              *,Host_dust=True,Host_gas=False,MW_dust=False,MW_gas=False,DLA=False,igm_att='meiksin'):
    """Power-law SED times a smoothly broken power-law light curve, attenuated by ``sed_extinction``.

    The extinction switches (Host_dust, Host_gas, MW_dust, MW_gas, DLA,
    igm_att) were previously read from module globals that are never defined,
    so every call raised NameError. They are now keyword-only arguments
    (positional calls are unchanged); the defaults match the configuration
    hard-coded in ``Flux_template``.
    """
    # Local import: sed_extinction is otherwise only imported in a later cell
    from grb_photoz.extinction_correction import sed_extinction
    extinction = sed_extinction(wvl,z,Av,ext_law=ext_law,Host_dust=Host_dust,Host_gas=Host_gas,
                                MW_dust=MW_dust,MW_gas=MW_gas,DLA=DLA,igm_att=igm_att)
    Flux= norm *F0 * ((t/t1)**(-s*alpha1) + (t/t1)**(-s*alpha2))**(-1./s) * (wvl/wvl0)**beta * extinction
    return Flux

def Flux_template(wvl,F0,wvl0,norm,beta,z,Av):
    """Power-law SED times ``sed_extinction`` with a fixed configuration.

    The configuration is: 'calzetti' law for host dust, no host gas, no Milky Way
    terms, no DLA, and 'meiksin' IGM attenuation.
    """
    intrinsic = SPL_sed(wvl,F0,wvl0,norm,beta)
    attenuation = sed_extinction(wvl,z,Av,ext_law='calzetti',Host_dust=True,Host_gas=False,
                                 MW_dust=False,MW_gas=False,DLA=False,igm_att='meiksin')
    return intrinsic*attenuation

def Flare(t, F0, tmid, sigma):
    """Gaussian flare profile: F0 * exp(-0.5 * ((t - tmid) / sigma)**2).

    The previous version printed ``t`` and never returned the flux (it
    returned None).
    """
    return F0 * np.exp(-0.5 * ((t-tmid)/sigma)**2)

def BPL_lc_flare(t,F0,norm,alpha1,alpha2,t1,s, F0f,tmid,sigma):
    """Smoothly broken power-law light curve plus a Gaussian flare centred on tmid."""
    ratio = t/t1
    afterglow = norm*F0*(ratio**(-s*alpha1) + ratio**(-s*alpha2))**(-1./s)
    flare = F0f*np.exp(-0.5*((t-tmid)/sigma)**2)
    return afterglow + flare

def SPL_lc_flare(t,F0,t0,norm,alpha, F0f,tmid,sigma):
    """Single power-law light curve plus a Gaussian flare centred on tmid."""
    afterglow = norm*F0*(t/t0)**(-alpha)
    flare = F0f*np.exp(-0.5*((t-tmid)/sigma)**2)
    return afterglow + flare

def SPL(wvl,t,F0,wvl0,t0,norm,beta,alpha):
    """Separable power law: norm * F0 * (wvl/wvl0)**beta * (t/t0)**(-alpha)."""
    spectral = (wvl/wvl0)**beta
    temporal = (t/t0)**(-alpha)
    return norm*F0*spectral*temporal

def BPL(wvl,t,F0,wvl0,t0,norm,beta,alpha1,alpha2,s):
    """Power-law SED times a smoothly broken power-law light curve (break at t0, smoothness s)."""
    spectral = (wvl/wvl0)**beta
    ratio = t/t0
    temporal = (ratio**(-s*alpha1) + ratio**(-s*alpha2))**(-1./s)
    return norm*F0*spectral*temporal

In [None]:
# Class to fit each light curve with iminuit
class Chi2Functor_lc:
    """Chi-square cost of a light-curve model ``f(t, *params)`` for iminuit.

    The fit-parameter signature is faked dynamically: ``func_code`` is built
    from the signature of ``f`` with the independent variable (time) removed.
    """

    def __init__(self,f,t,y,yerr):
        self.f = f
        self.t = t
        self.y = y
        self.yerr = yerr
        # Drop the independent variable from the model signature
        params = describe(f)[1:]
        self.func_code = make_func_code(params)
        self.func_defaults = None  # keeps np.vectorize happy

    def __call__(self,*arg):
        """Return sum over points of (y - f(t, *arg))**2 / yerr**2 (variable number of parameters)."""
        total = 0
        for t_i, y_i, err_i in zip(self.t, self.y, self.yerr):
            total += (y_i - self.f(t_i, *arg))**2/err_i**2
        return total
    
    
# Class to fit a SED at a given time with iminuit
class Chi2Functor_sed:
    """Chi-square cost of an SED model ``f(x, *params)`` for iminuit.

    The fit-parameter signature is faked dynamically: ``func_code`` is built
    from the signature of ``f`` with the independent variable (wavelength) removed.
    """

    def __init__(self,f,x,y,yerr):
        self.f = f
        self.x = x
        self.y = y
        self.yerr = yerr
        # Drop the independent variable from the model signature
        params = describe(f)[1:]
        self.func_code = make_func_code(params)
        self.func_defaults = None  # keeps np.vectorize happy

    def __call__(self,*arg):
        """Return sum over points of (y - f(x, *arg))**2 / yerr**2 (variable number of parameters)."""
        total = 0
        for x_i, y_i, err_i in zip(self.x, self.y, self.yerr):
            total += (y_i - self.f(x_i, *arg))**2/err_i**2
        return total

In [None]:
def fit_lc(observations,spline,model,method='best'):
    """Fit the light curve of each band separately with iminuit.

    Each band (rows of ``spline`` sharing the same 'eff_wvl') is fitted on its
    spline-interpolated light curve with unit flux uncertainties.

    Parameters
    ----------
    observations : astropy Table
        Observed photometry. Currently unused (kept for interface compatibility).
    spline : astropy Table
        Spline light curves with columns 'eff_wvl', 'Time', 'Flux' and 'filter'.
    model : str
        Light-curve model: 'BPL', 'SPL', 'BPL_flare' or 'SPL_flare'.
    method : str
        Currently unused.

    Returns
    -------
    astropy Table
        One row per band with columns name, band, F0, norm, the model's shape
        parameters, chi2 and, for flare models, F0f, tmid, sigma.

    Raises
    ------
    ValueError
        If ``model`` is unknown.
    """
    # Shape parameters of each model, in output-column order
    shape_params = {'BPL': ['alpha1', 'alpha2', 't1', 's'],
                    'SPL': ['alpha', 't0'],
                    'BPL_flare': ['alpha1', 'alpha2', 't1', 's'],
                    'SPL_flare': ['alpha', 't0']}
    if model not in shape_params:
        raise ValueError('"%s" model for fitting the light curve unknown. It should be one of: %s'
                         % (model, ', '.join(shape_params)))
    flare_params = ['F0f', 'tmid', 'sigma'] if model.endswith('_flare') else []
    stored_params = ['F0', 'norm'] + shape_params[model] + flare_params
    # Output columns: chi2 comes after the shape parameters and before the flare ones
    columns = OrderedDict((key, []) for key in
                          ['name', 'band', 'F0', 'norm'] + shape_params[model] + ['chi2'] + flare_params)

    grb_name = GRB
    norm_guess = 1
    # Same temporal index for every band (achromatic decay assumed); kept fixed in the SPL fit
    alpha = 1.16481928375
    fix_alpha = True

    wvl_list = [key[0] for key in spline.group_by(['eff_wvl']).groups.keys]

    for wvl in wvl_list:
        mask_wvl = spline['eff_wvl']==wvl
        time = spline['Time'][mask_wvl]
        flux = spline['Flux'][mask_wvl]
        # The spline carries no uncertainty: use unit weights
        flux_err = np.ones(len(flux))

        # -------Guess initial values-----------
        F0_guess = flux[0]
        t1_guess = _guess_break_time(time, flux)

        fix_norm = False
        if wvl == 5468:
            # Ad hoc tweak for the V band (eff_wvl 5468 A): raise F0 by 20% and fix norm
            F0_guess *= 1.2
            fix_norm = True

        print (t1_guess,F0_guess)

        if model == 'BPL':
            chi2_func = Chi2Functor_lc(BPL_lc,time,flux,flux_err)
            kwdarg = dict(pedantic=True,print_level=2,F0=F0_guess,
                          fix_F0=True,norm=norm_guess,fix_norm=False,limit_norm=(0.1,10),
                          alpha1=-0.5,limit_alpha1=[-10,10],alpha2=0.5,limit_alpha2=[-10,10],
                          t1=1100,fix_t1=False,limit_t1=[800,5000],s=3,limit_s=[0.01,20])
        elif model == 'SPL':
            chi2_func = Chi2Functor_lc(SPL_lc,time,flux,flux_err)
            kwdarg = dict(pedantic=True,print_level=2,F0=F0_guess,
                          fix_F0=True,norm=norm_guess,fix_norm=fix_norm,limit_norm=(0.1,10),
                          alpha=alpha,limit_alpha=[-10,10],fix_alpha=fix_alpha,t0=t1_guess,fix_t0=True,limit_t0=[0,None])
        elif model == 'BPL_flare':
            chi2_func = Chi2Functor_lc(BPL_lc_flare,time,flux,flux_err)
            kwdarg = dict(pedantic=True,print_level=2,F0=F0_guess,
                          fix_F0=True,norm=1,fix_norm=True,limit_norm=(0.1,10),
                          alpha1=1.6,limit_alpha1=[-10,10],alpha2=0.,fix_alpha2=True,limit_alpha2=[-10,10],
                          t1=1100,fix_t1=False,limit_t1=[800,5000],s=-1,fix_s=True,limit_s=[0.01,20],
                          F0f=F0_guess/2, limit_F0f=[0,F0_guess],fix_F0f=False, tmid=2000, fix_tmid=False,
                          limit_tmid=[0,None], sigma=1000,fix_sigma=False,
                          limit_sigma =[0,None])
        else:  # 'SPL_flare' (model validated above)
            chi2_func = Chi2Functor_lc(SPL_lc_flare,time,flux,flux_err)
            kwdarg = dict(pedantic=True,print_level=2,F0=F0_guess,
                          fix_F0=True,norm=norm_guess,fix_norm=False,limit_norm=(0.1,10),
                          alpha=1.6,limit_alpha=[-10,10],t0=t1_guess,
                          fix_t0=True,limit_t0=[0,None],
                          F0f=F0_guess/2, limit_F0f=[0,None],fix_F0f=False, tmid=2000, fix_tmid=False,
                          limit_tmid=[1000,None], sigma=500,fix_sigma=False,
                          limit_sigma =[200,None])

        m = Minuit(chi2_func,**kwdarg)
        m.set_strategy(1)
        d,l=m.migrad()
        print ('Valid Minimum: %s ' % str(m.migrad_ok()))
        print ('Is the covariance matrix accurate: %s' % str(m.matrix_accurate()))

        # Store the best-fit parameters of this band (every model, including BPL_flare)
        columns['name'].append(grb_name)
        columns['band'].append(spline['filter'][mask_wvl][0])
        for key in stored_params:
            columns[key].append(m.values[key])
        columns['chi2'].append(d.fval)

    lc_fit_params = Table(list(columns.values()), names=list(columns.keys()))
    print (lc_fit_params)
    return lc_fit_params


def _guess_break_time(time, flux):
    """Return the time of an interior flux maximum, else of an interior minimum, else time[0]."""
    n_points = len(flux)
    for idx in (np.argmax(flux), np.argmin(flux)):
        if 0 < idx < n_points-1:
            return time[idx]
    return time[0]

In [None]:
def plot_lc_fit_check(observations, spline,lc_fit_params, model, plot,output_dir='results/', filename_suffix=''):
    """Plot, for each band, the observations, the spline and the best-fit light curve.

    Parameters
    ----------
    observations : astropy Table
        Observed photometry with columns 'eff_wvl', 'Time', 'Flux', 'Flux_err'.
    spline : astropy Table
        Spline light curves with columns 'eff_wvl', 'Time', 'Flux', 'filter'.
    lc_fit_params : astropy Table
        Output of ``fit_lc`` obtained with the same ``model``.
    model : str
        One of 'BPL', 'SPL', 'BPL_flare', 'SPL_flare'.
    plot, output_dir
        Currently unused (kept for interface compatibility); the figure is
        always saved in the working directory.
    filename_suffix : str
        Inserted in the output name lc_fit_<GRB>_<model><suffix>2.png.

    Raises
    ------
    ValueError
        If ``model`` is unknown.
    """
    # Light-curve function of each model and the order of its fit parameters
    model_curves = {
        'BPL': (BPL_lc, ['F0', 'norm', 'alpha1', 'alpha2', 't1', 's']),
        'SPL': (SPL_lc, ['F0', 't0', 'norm', 'alpha']),
        'BPL_flare': (BPL_lc_flare, ['F0', 'norm', 'alpha1', 'alpha2', 't1', 's', 'F0f', 'tmid', 'sigma']),
        'SPL_flare': (SPL_lc_flare, ['F0', 't0', 'norm', 'alpha', 'F0f', 'tmid', 'sigma']),
    }
    if model not in model_curves:
        raise ValueError('Unknown light-curve model "%s"' % model)
    model_func, param_names = model_curves[model]

    # Same GRB name as the one fit_lc stores in the 'name' column
    grb_name = GRB
    z_sim = 1.71

    wvl_list = [key[0] for key in spline.group_by(['eff_wvl']).groups.keys]
    # One color per fitted band
    colors = cm.rainbow(np.linspace(0,1,len(lc_fit_params['band'])))
    # Common time grid on which the fitted models are drawn
    time_fit = np.linspace(1e3,1e5,1000)

    plt.figure()
    for i,wvl in enumerate(wvl_list):
        mask_wvl = spline['eff_wvl']==wvl
        time = spline['Time'][mask_wvl]
        flux = spline['Flux'][mask_wvl]
        band_name = spline['filter'][mask_wvl][0]

        # Observations of this band
        mask_obs = observations['eff_wvl']==wvl
        plt.errorbar(observations[mask_obs]['Time'],observations[mask_obs]['Flux'],
                     yerr=observations[mask_obs]['Flux_err'],color=colors[i],fmt='o',markersize=3, label ='Obs.')
        # Spline interpolation
        plt.plot(time,flux,color=colors[i],ls='--',lw=2, label='spline')

        # Best-fit model: [0] picks the single row fitted for this GRB and band
        mask_fit = (lc_fit_params['name'] == grb_name) & (lc_fit_params['band'] == band_name)
        params = [float(lc_fit_params[key][mask_fit][0]) for key in param_names]
        plt.plot(time_fit, model_func(time_fit, *params), label=band_name, color=colors[i])

    plt.xscale('log')
    plt.yscale('log')
    plt.title('Light curve from T-To=%.0f to T-To=%.0f sec \n z=%.2f \n GRB%s' %
              (np.min(spline['Time']),np.max(spline['Time']),
               float(z_sim),grb_name))
    plt.xlabel(r'T-T$_{0}$ [seconds]')
    plt.ylabel(r'Flux [mJy]')
    # Do not duplicate legend entries and show only one point per label
    handles, labels = plt.gca().get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))
    plt.legend(by_label.values(), by_label.keys(),numpoints =1,loc='best')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("lc_fit_"+grb_name+'_'+model+filename_suffix+"2.png")

In [None]:
# Extract SED directly from spline
from scipy.interpolate import interp1d
from grb_photoz.extinction_correction import sed_extinction

def fit_SED(data, time_SED, method = 'spline', fit_law='SPL', plot=True):
    """Extract the SED at time ``time_SED`` and fit it with ``Flux_template``.

    Parameters
    ----------
    data : astropy Table
        method='spline': spline light curves with columns 'eff_wvl', 'Time'
        and 'Flux'. Each band is linearly interpolated at time_SED; bands whose
        time range does not contain time_SED are skipped (no extrapolation).
        method='fit': table returned by ``fit_lc`` (one row per band, assumed
        sorted by wavelength). The light-curve model is evaluated at
        time_SED, which allows extrapolation in time.
    time_SED : float
        Time since trigger (s) at which the SED is extracted.
    method : {'spline', 'fit'}
    fit_law : str
        Light-curve model stored in ``data`` when method='fit'. Only 'SPL' is supported.
    plot : bool
        Plot the extracted SED and the best-fit template.

    Returns
    -------
    numpy array
        Best-fit SED evaluated on np.linspace(4000, 9000, 200) Angstrom.

    Raises
    ------
    ValueError
        For an unknown method / fit_law / GRB, or when no band provides a flux at time_SED.
    """
    sed_wvl = []
    sed_flux = []
    if method == 'spline':
        wvl_list = [key[0] for key in data.group_by(['eff_wvl']).groups.keys]
        for wvl in wvl_list:
            mask_wvl = data['eff_wvl'] == wvl
            interp_flux = interp1d(data['Time'][mask_wvl], data['Flux'][mask_wvl], kind='linear')
            try:
                flux = interp_flux(time_SED)
            except ValueError:
                # time_SED outside this band's time range: interp1d does not extrapolate
                continue
            sed_flux.append(flux)
            sed_wvl.append(wvl)
        yerr = np.ones(len(sed_wvl))*0.01

    elif method == 'fit':
        if fit_law != 'SPL':
            raise ValueError('fit_law "%s" not supported with method="fit": only "SPL" is implemented' % fit_law)
        # Fit parameters are assumed already sorted by wavelength
        for i, band in enumerate(data['band']):
            sed_wvl.append(eff_wvl(band))
            sed_flux.append(SPL_lc(time_SED, float(data['F0'][i]), float(data['t0'][i]),
                                   float(data['norm'][i]), float(data['alpha'][i])))
        yerr = np.ones(len(sed_wvl))*1e-3

    else:
        raise ValueError('Unknown method "%s": it should be "spline" or "fit"' % method)

    if not sed_flux:
        raise ValueError('No band provides a flux at time_SED=%s' % time_SED)

    # Starting values per GRB: redshift and extinction (both fixed in the fit) and spectral slope
    initial_guesses = {'080607': (3.03, 1.5, 3),
                       '061126': (0.168, 0.0, 1),
                       '091020': (1.71, 0, 1)}
    if GRB not in initial_guesses:
        raise ValueError('No SED starting values defined for GRB %s' % GRB)
    z, Av, beta = initial_guesses[GRB]

    # Normalise on the last (presumably longest-wavelength) extracted point:
    # F0 and wvl0 must refer to the same band.
    F0_guess = sed_flux[-1]
    wvl0_guess = sed_wvl[-1]
    norm_guess = 1
    chi2_func = Chi2Functor_sed(Flux_template,np.array(sed_wvl),np.array(sed_flux),yerr)
    kwdarg = dict(pedantic=False,print_level=0,
                  F0=F0_guess, fix_F0=True,
                  wvl0=wvl0_guess, fix_wvl0=True,
                  norm=norm_guess,fix_norm=False,limit_norm=(0.1,10),
                  beta=beta, fix_beta=False, limit_beta=[0,4],
                  z=z, fix_z=True,
                  Av=Av, fix_Av=True, limit_Av=[0,3])
    m = Minuit(chi2_func,**kwdarg)
    m.set_strategy(1)
    m.migrad()

    wvl_fit=np.linspace(4000,9000,200)
    sed = Flux_template(wvl_fit, m.values['F0'], m.values['wvl0'], m.values['norm'],
                        m.values['beta'], m.values['z'], m.values['Av'])

    if plot:
        plt.figure()
        plt.plot(wvl_fit, sed)
        plt.errorbar(sed_wvl,sed_flux,xerr=0,yerr=yerr,fmt='o')
        plt.yscale('log')
        plt.xlabel(r'$\lambda$ [Angstroms]')
        # Fluxes come from the spline / light-curve fits, which are in mJy
        plt.ylabel(r'Flux [mJy]')
        plt.tight_layout()
    return sed


In [None]:
# Fit the spline light curve of every band with a single power law (SPL)
fit_test=fit_lc(observations,observations_spline,'SPL',method='best')

In [None]:
# Display the best-fit light-curve parameters (one row per band)
fit_test

In [None]:
# Visual check: observations, splines and best-fit SPL model of every band
plot_lc_fit_check(observations, observations_spline, fit_test, 'SPL', True,output_dir='', filename_suffix='')


In [None]:
# Extract and fit the SED at 19000 s, evaluating the per-band SPL fits at that time
best_sed=fit_SED(fit_test, 19000,method='fit',plot=True)

In [None]:
# Build the light-curve table: best-fit SED at each time, on a common wavelength grid
time_ = np.linspace(1e3, 1e5, 1000)  # num must be an int: a float raises TypeError in recent NumPy
wvl_fit = np.linspace(3900, 8000, 100)
# fit_SED returns the SED sampled on np.linspace(4000, 9000, 200), not on wvl_fit:
# interpolate it so that every flux is paired with its own wavelength.
# NOTE: np.interp holds the 4000 A value constant below 4000 A (no extrapolation).
sed_wvl_grid = np.linspace(4000, 9000, 200)

lc = [time_, wvl_fit]
for t in time_:
    sed = fit_SED(fit_test, t, method='fit', plot=False)
    lc.append(np.interp(wvl_fit, sed_wvl_grid, sed))

# One row per (time, wavelength): all wavelengths of the first time, then the next time, ...
time_list = np.repeat(time_, len(wvl_fit))
wvl_list = np.tile(wvl_fit, len(time_))
flux_list = np.concatenate(lc[2:])

lc_table = Table([time_list, wvl_list, flux_list], names=['Time', 'wvl', 'flux'])

In [None]:
# Export the light-curve table to pyETC's data/LightCurves folder, under the file
# name the ETC cell reads (LC_<GRB>_VT_blue.txt). pyETC_DIR must be set.
lc_table.write(os.getenv('pyETC_DIR')+'/pyETC/data/LightCurves/LC_%s_VT_blue.txt' % GRB, format ='ascii',overwrite=True)

In [None]:
# Compute the light curves observed by Colibri and VT with the ETC

from pyETC.pyETC import etc

# Colibri epochs; VT epochs are the mid-points between them plus a late one at 9e4 s
time = np.logspace(3.1,4.9,10)
time_VT = np.append((time[1:]+time[:-1])/2, 9e4)

mag_colibri=[]
mag_VT=[]
snr_colibri=[]
snr_VT=[]

telescopes=['colibri', 'VT']

# Instrument settings of each telescope: channel, filter, wavelength range/step, exposure time
instrument_settings = {
    'colibri': {'channel': 'DDRAGO-B', 'filter_band': 'r',
                'lambda_start': 0.51, 'lambda_end': 0.69, 'lambda_step': 0.001,
                'exptime': 30},
    'VT': {'channel': 'VIS-B', 'filter_band': 'blue',
           'lambda_start': 0.39, 'lambda_end': 0.7, 'lambda_step': 0.001,
           'exptime': 100},
}
# Output lists and observing epochs of each telescope
results = {'colibri': (mag_colibri, snr_colibri, time),
           'VT': (mag_VT, snr_VT, time_VT)}

for tel in telescopes:
    ETC = etc(configFile=configFile, name_telescope=tel)

    # Source: simulated GRB light curve read from the exported file
    ETC.information['etc_type'] = 'snr'
    ETC.information['object_type'] = 'grb_sim'
    ETC.information['object_folder'] = '/data/LightCurves/'
    ETC.information['object_file'] = 'LC_%s_VT_blue.txt' % GRB
    ETC.information['grb_model'] = 'LightCurve'

    for key, value in instrument_settings[tel].items():
        ETC.information[key] = value

    mags, snrs, epochs = results[tel]
    for i in range(len(time)):
        ETC.information['t_sinceBurst'] = epochs[i]
        ETC.sim()
        mags.append(ETC.information['mag'])
        snrs.append(ETC.information['SNR'])


In [None]:
# Convert the simulated AB magnitudes to flux densities (Jy, zero point 3631 Jy).
# The magnitude error is approximated by 1/SNR and propagated as dF = 0.4 ln(10) F dm.
mag_colibri=np.array(mag_colibri)
mag_colibri_err=1/np.array(snr_colibri)
Flux_colibri= 3631 * 10**(-0.4*mag_colibri)
Flux_err_colibri = Flux_colibri * 0.4 *np.log(10)*mag_colibri_err

mag_VT=np.array(mag_VT)
mag_VT_err=1/np.array(snr_VT)
Flux_VT= 3631 * 10**(-0.4*mag_VT)
Flux_err_VT = Flux_VT * 0.4 *np.log(10)*mag_VT_err

In [None]:
# Plot the simulated light curves. Points are drawn at time + exptime/2 and the
# horizontal error bars span +/- exptime/2 around them.
exptime_colibri = 30
exptime_VT = 100

fig, ax = plt.subplots()
ax.errorbar(time+exptime_colibri/2, Flux_colibri, xerr=exptime_colibri/2, yerr=Flux_err_colibri,
            label='Colibri r', color='red', fmt='o', ms=2)
ax.errorbar(time_VT+exptime_VT/2, Flux_VT, xerr=exptime_VT/2, yerr=Flux_err_VT,
            label='VT blue', color='blue', fmt='o', ms=2)

mn, mx = ax.set_ylim(1e-6, 1e-3)
ax.set_ylabel(r'Flux [Jy]')
ax.set_xscale('log')
ax.set_yscale('log')

# Secondary y axis with the AB magnitudes matching the flux limits
ax2 = ax.twinx()
ax2.set_ylim(-2.5*np.log10(mn/3631), -2.5*np.log10(mx/3631))
ax2.set_ylabel('AB mag')

ax.set_xlim(1e3, 1.2e5)
ax.set_title('GRB%s light curves for Colibri and VT' % GRB)
ax.set_xlabel('Time since GRB trigger [s]')
ax.legend()
plt.tight_layout()
plt.savefig('GRB%s_LC_VT100.png' % GRB)

In [None]:
# Save the simulated light curves. Columns: time + exptime/2, exptime, flux (Jy),
# flux error, AB magnitude, magnitude error.
exptime_colibri = 30
exptime_VT = 100
res = np.array([time+exptime_colibri/2, np.ones(len(time))*exptime_colibri,
                Flux_colibri, Flux_err_colibri, mag_colibri, mag_colibri_err]).T
# Size the VT exposure column from the VT epochs themselves
res_VT = np.array([time_VT+exptime_VT/2, np.ones(len(time_VT))*exptime_VT,
                   Flux_VT, Flux_err_VT, mag_VT, mag_VT_err]).T

np.savetxt('LC_Colibri.dat', res, delimiter=',')
np.savetxt('LC_VT.dat', res_VT, delimiter=',')

In [None]:
# Sky coordinates (RA, Dec in degrees) of the supported GRBs
grb_coordinates = {
    '061126': (86.625, 64.190),
    '080607': (194.967, 15.900),
    '091020': (138.133, 67.167),
}
if GRB not in grb_coordinates:
    raise ValueError('No coordinates defined for GRB %s' % GRB)
ra, dec = grb_coordinates[GRB]

# Make PSF

In [None]:
from pyETC.pyETC import etc
from ImSimpy.utils.PSFUtils import createPSF, convolvePSF


name_telescope='colibri'

# PSF sampling parameters, passed unchanged to createPSF
PSF_size=[256,256]
oversampling=1
oversamp2=15
ETC=etc(configFile=configFile,name_telescope=name_telescope)

# Filter bands to model for each telescope (full Colibri set: ['g','r','i','z','J','H'])
bands_by_telescope = {'colibri': ['r'], 'VT': ['red'], 'GWAC': ['R']}

# PSF recipe per telescope: (atmosphere PSF type, instrument PSF type, fixed seeing or None).
# None means "use the line-of-sight seeing computed by the ETC for the band".
# VT: two gaussians, each with seeing 0.94". Assuming createPSF's `seeing` is a FWHM, their
# convolution has FWHM 0.94*sqrt(2) ~ 1.33", i.e. sigma = sqrt(0.4"^2 optics + 0.4"^2 platform)
# ~ 0.57". Pointing jitter (~1") is not included; with it sigma would be ~1.15".
psf_recipes = {
    'VT': ('gaussian', 'gaussian', 0.94),
    'colibri': ('moffat', 'airy', None),
    'GWAC': ('moffat', 'airy', None),
}

if name_telescope not in bands_by_telescope:
    # Without this check `bands` would silently keep its value from an earlier cell
    raise ValueError('No PSF recipe for telescope %s' % name_telescope)
bands = bands_by_telescope[name_telescope]

psf_root = os.getenv('ImSimpy_DIR')+'/ImSimpy/data/psf'
name_inst=psf_root+'/instrument/%s/instrument' % output_dir
name_atm=psf_root+'/atmosphere/%s/atmosphere' % output_dir


def _write_psf(filename, PSF_type, psf_seeing, pixsize, pixscale, wvl, DM1, DM2, f_length):
    """Write one PSF component (atmosphere or instrument) to `filename` with createPSF.

    Uses the cell-level sampling settings PSF_size, oversampling and oversamp2.
    """
    createPSF(filename=filename,PSF_type=PSF_type,imsize=PSF_size,pixel_size=[pixsize,pixsize],
              pixel_scale=pixscale,eff_wvl=wvl,seeing=psf_seeing,DM1=DM1,DM2=DM2,focal_length=f_length,
              oversamp=oversampling,oversamp2=oversamp2,beta=2,disp=False,unsigned16bit=False)


for band in bands:
    # Select the camera channel serving this band
    if name_telescope == 'colibri':
        if band in ['g','r','i','blue']:
            ETC.information['channel']='DDRAGO-B'
        elif band in ['z','y','red']:
            ETC.information['channel']='DDRAGO-R'
        elif band in ['J','H']:
            ETC.information['channel']='CAGIRE'
    elif name_telescope == 'VT':
        if band in ['blue']:
            ETC.information['channel']='VIS-B'
        elif band in ['red']:
            ETC.information['channel']='VIS-R'

    ETC.information['filter_band']=band

    # Run the ETC to get the seeing along the line of sight, scaled to airmass and wavelength
    ETC.sim()
    seeing = ETC.information['seeing_los_arcsec']   # in arcsec

    pixsize=ETC.information['cameras'][ETC.information['channel']]['Photocell_SizeX']
    pixscale=ETC.information['pixelScale_X']    # assume same in Y
    wvl=ETC.information['effWavelength']
    DM1=ETC.information['D_M1']
    DM2=ETC.information['D_M2']
    f_length=ETC.information['foc_len']

    print ('band: %s   PixSize: %.2f   PixScale: %.2f   wvl_eff: %.2f   \nDM1: %.2f   DM2: %.2f    f_Length: %.2f  seeing: %.2f' % (band,pixsize,pixscale,wvl,DM1,DM2,f_length,seeing))

    atm_type, inst_type, fixed_seeing = psf_recipes[name_telescope]
    psf_seeing = seeing if fixed_seeing is None else fixed_seeing

    name_atm_band=name_atm+'_%s_%s.fits' % (name_telescope,band)
    name_inst_band=name_inst+'_%s_%s.fits' % (name_telescope,band)
    name_PSF_total=psf_root+'/total_PSF/%s/PSF_total_%s_%s.fits' % (output_dir,name_telescope,band)

    _write_psf(name_atm_band, atm_type, psf_seeing, pixsize, pixscale, wvl, DM1, DM2, f_length)
    _write_psf(name_inst_band, inst_type, psf_seeing, pixsize, pixscale, wvl, DM1, DM2, f_length)
    # Total PSF = atmosphere PSF convolved with the instrument PSF
    convolvePSF(filename1=name_atm_band,filename2=name_inst_band,filename3=name_PSF_total)

# Compute images

In [None]:
from ImSimpy.ImSimpy import ImageSimulator

name_telescope='colibri'

time_grb=1e4    # time since burst of the simulated exposure (also used in the image file name)
expTime=100     # exposure time


IS=ImageSimulator(configFile=configFile,name_telescope=name_telescope)

# Read the config file, then override it with the GRB-specific settings
IS.readConfigs()
IS.information['etc_type']='snr'
IS.information['object_type']='grb_sim'
IS.config['object_folder']='/data/LightCurves/'
IS.information['object_file']='LC_%s.txt' % GRB
IS.information['grb_model']='LightCurve'

IS.config['grb_coord_type']= 'RADEC'
IS.config['grb_coords']= [ra,dec]

# Position of the reference pixel centered on GRB position
IS.config['RA']= ra
IS.config['DEC']= dec

IS.config['lambda_start']= 0.51
IS.config['lambda_end']= 0.9
IS.config['lambda_step']= 0.001


def _set_detector_maps(config, suffix):
    """Point the detector calibration maps at '<output_dir>/<stem>_<suffix>.fits'.

    suffix is 'vis' for the DDRAGO channels and 'nir' for CAGIRE.
    """
    for key, stem in (('GainMapFile', 'Gain'),
                      ('VignettingFile', 'Vignetting'),
                      ('OffsetFile', 'Offset'),
                      ('DeadPixFile', 'DeadPixs'),
                      ('HotPixFile', 'HotPixs')):
        config[key] = '%s/%s_%s.fits' % (output_dir, stem, suffix)


bands = ['g','r','i','z','J','H']
if name_telescope == 'colibri':
    bands = ['r']
elif name_telescope == 'VT':
    bands = ['red']

for band in bands:

    print ('\n')
    if name_telescope == 'colibri':
        if band in ['g', 'r', 'i', 'y', 'z']:
            # Visible bands: DDRAGO blue arm (g, r, i) or red arm (y, z), same 'vis' maps
            IS.config['channel'] = 'DDRAGO-B' if band in ['g', 'r', 'i'] else 'DDRAGO-R'
            _set_detector_maps(IS.config, 'vis')
            IS.config['SourcesList']['generate']['catalog'] = 'Panstarrs'
            IS.config['SourcesList']['generate']['radius'] = 0.3 # in degrees, no need to get all the sources

        elif band in ['J', 'H']:
            IS.config['channel']='CAGIRE'
            _set_detector_maps(IS.config, 'nir')
            IS.config['SourcesList']['generate']['catalog'] = 'II/246'
            IS.config['SourcesList']['generate']['radius'] = 26 # in arcmin

    elif name_telescope == 'VT':
        if band =='blue':
            IS.config['channel']='VIS-B'
        elif band == 'red':
            IS.config['channel']='VIS-R'
        IS.config['SourcesList']['generate']['catalog'] = 'NOMAD-1'
        IS.config['SourcesList']['generate']['radius'] = 26 # in arcmin


    IS.config['exptime']=expTime
    IS.config['filter_band']=band
    IS.config['t_sinceBurst']=time_grb

    IS.config['SourcesList']['generate']['band'] = band
    IS.config['SourcesList']['generate']['RA'] = ra
    IS.config['SourcesList']['generate']['DEC'] = dec

    IS.config['SourcesList']['generate']['output']="%s/Sources_%s.txt" % (output_dir,band)
    IS.config['output']='%s/image_%s_%s_%s.fits' % (output_dir,name_telescope,band,time_grb)
    IS.config['PSF']['total']['method']='compute'
    # NOTE(review): the PSF cell writes total_PSF/<output_dir>/PSF_total_<telescope>_<band>.fits,
    # but this path has no telescope part — confirm which file ImSimpy uses with method='compute'.
    IS.config['PSF']['total']['file']='total_PSF/%s/PSF_total_%s.fits' % (output_dir,band)
    IS.simulate('data')

    # Keep the GRB pixel position (presumably set by IS.simulate) for the display cell
    if band in ['J', 'H']:
        grb_coords_pix_X_nir=int(IS.config['grb_coords_pix_X'])
        grb_coords_pix_Y_nir=int(IS.config['grb_coords_pix_Y'])
    else:
        grb_coords_pix_X_vis=int(IS.config['grb_coords_pix_X'])
        grb_coords_pix_Y_vis=int(IS.config['grb_coords_pix_Y'])
    print (IS.config['grb_mag'])

In [None]:
# Pixel position of the GRB in the visible-channel image produced by the simulation cell
print('%s %s' % (grb_coords_pix_X_vis, grb_coords_pix_Y_vis))
# NIR counterpart, only defined when a J or H band was simulated:
#print (grb_coords_pix_X_nir, grb_coords_pix_Y_nir)

In [None]:
from astropy.io import fits
import pyregion
from matplotlib.colors import LogNorm
from astropy.visualization import MinMaxInterval, SqrtStretch,LogStretch,AsinhStretch, LinearStretch,ImageNormalize,ZScaleInterval

# Show a stamp of each simulated image centred on the GRB: one row per epoch, one column per band.
# Full set used elsewhere:
#   bands=['g','r','i','z','J','H'], times=['100','900','1800','3600'], labels=('100s','15m','30m','1h')

name_telescope='colibri'

if name_telescope == 'colibri':
    bands=['r']
elif name_telescope == 'VT':
    bands=['red']
else:
    raise ValueError('No display settings for telescope %s' % name_telescope)

# Epochs must match the time_grb used in the image file names (str(1e4) == '10000.0')
times=['10000.0']
# One label per epoch; the trailing comma makes this a tuple (('10000s') is just a string)
labels=('10000s',)

# Half-size in pixels of the stamp cut around the GRB (same for visible and NIR bands)
width=[15,15]

# DS9 crosshair pointing at the GRB at the centre of the 30x30 stamp
region_string = """
       # Region file format: DS9 version 4.1
       global color=white dashlist=8 3 width=3
       image
       line(14,3,14,10) # line=0 0
       line(19,15,26,15) # line=0 0
       """

# squeeze=False keeps `ax` a 2-D array even for a single epoch/band, so ax[j,i] always works
fig, ax = plt.subplots(len(times),len(bands),figsize=(10,6),squeeze=False)
fig.subplots_adjust(hspace=0.05, wspace=0.01)

for j, time_grb in enumerate(times):
    for i,band in enumerate(bands):
        fname='image_%s_%s_%s' % (name_telescope,band,time_grb)
        image=fits.getdata(os.getenv('ImSimpy_DIR')+'/ImSimpy/images/%s/%s.fits' % (output_dir,fname))

        # Centre the stamp on the GRB pixel recorded by the image-simulation cell
        if band in ['J','H']:
            center=[int(grb_coords_pix_X_nir),int(grb_coords_pix_Y_nir)]
        else:
            center=[int(grb_coords_pix_X_vis),int(grb_coords_pix_Y_vis)]

        # NOTE(review): the first (row) axis is sliced with X and the second with Y; astropy
        # returns FITS data indexed [y, x], so confirm ImSimpy's X/Y convention.
        image_grb=image[center[0]-width[0]:center[0]+width[0],center[1]-width[1]:center[1]+width[1]]

        # Display limits from IRAF-like zscale on the full image, with a linear stretch
        vmin,vmax=ZScaleInterval(nsamples=1000, contrast=0.25, max_reject=0.5,
                                 min_npixels=5, krej=2.5, max_iterations=5).get_limits(image)
        norm = ImageNormalize(vmin=vmin, vmax=vmax, stretch=LinearStretch())

        ax[j,i].imshow(image_grb,interpolation='none',cmap='gray',norm=norm)
        if j == len(times)-1:
            # Band name inside each bottom-row panel
            ax[j,i].text(0.55,0.1, '%s' % band, transform=ax[j,i].transAxes, ha='right',color='white',fontsize=18,fontweight='bold')
        ax[j,i].axis('off')

        # Build fresh crosshair patches for this panel and overlay them
        region = pyregion.parse(region_string)
        patch_list, text_list = region.get_mpl_patches_texts()
        for patch in patch_list:
            ax[j,i].add_patch(patch)
        for text in text_list:
            ax[j,i].add_artist(text)

    # Epoch label, once per row, placed in the first panel of the row
    ax[j,0].text(0.97,0.8, '%s' % labels[j], transform=ax[j,0].transAxes, ha='right',color='white',fontsize=12,fontweight='bold')

fig.savefig(os.getenv('ImSimpy_DIR')+'/ImSimpy/images/%s/niceplot.png' % output_dir)