In [None]:
#Making a WR-purposed MCMC fitting suite - to optimise fittings of TESS light curves of WR stars
import os, sys

import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.ticker as ticker

import numpy as np

import astropy
from astropy.io import fits
from astropy.io import ascii
import astropy.units as u
# from astropy.utils.data import get_pkg_data_filename   #time-series extraction tool
# from astropy.timeseries import TimeSeries

import math 

import lightkurve as lk    #good time-series library

import pandas as pd 

from scipy.optimize import curve_fit
from scipy.optimize import minimize

import emcee
import corner

from gls import Gls

In [None]:
#Searching lightcurves
def search_lc(name, mission, exptime, radius):
    """Query lightkurve for the light curves of one target.

    Parameters
    ----------
    name : str
        Target name passed to ``lk.search_lightcurve``.
    mission : str
        Mission filter (the callers use 'TESS').
    exptime : int or str
        Exposure-time filter forwarded unchanged.
    radius : float
        Cone-search radius forwarded unchanged (presumably arcseconds - confirm
        against the lightkurve docs).

    Returns
    -------
    The object returned by ``lk.search_lightcurve``.
    """
    query_opts = dict(mission=mission, exptime=exptime, radius=radius)
    return lk.search_lightcurve(name, **query_opts)

In [None]:
#Auto-generate text files of light curves
def lc_dat_generate(search_result, name):
    """Download every light curve in a search result and write each to a text file.

    Files are named ``{name}_lc{i}.txt`` (``i`` = position in the search result),
    written with astropy's ``ascii.fixed_width`` format, overwriting existing files.

    Parameters
    ----------
    search_result : object with ``download_all()``
        Presumably a lightkurve ``SearchResult`` (as returned by ``search_lc``).
    name : str
        Prefix for the output file names.

    Returns
    -------
    None
    """
    lc_data = search_result.download_all()
    # lightkurve's download_all() returns None (not an empty collection) for an
    # empty search; len(None) would otherwise raise a TypeError.
    if lc_data is None:
        return
    for idx, lc in enumerate(lc_data):
        lc.write("{}_lc{}.txt".format(name, idx), format="ascii.fixed_width", overwrite=True)

In [None]:
def lc_view(lc_name, fold_factor, wrap_no, save_fig, close_figs=False):
    """Plot (and optionally save) raw and phase-folded light curves for each target.

    For every target, each 120 s TESS light curve found within radius=60 is
    downloaded. Three figures are drawn for it:
    the normalised PDCSAP flux, the curve folded at the periodogram's period of
    maximum power, and the curve folded at ``fold_factor`` times that period.

    Parameters
    ----------
    lc_name : iterable of str
        Target names (e.g. ``all_wr_list``).
    fold_factor : float
        Multiplier applied to the max-power period for the third plot.
    wrap_no : float
        ``wrap_phase`` passed to ``lc.fold`` for the third plot.
    save_fig : bool
        If True, save the three figures as PNGs in ``./<target>/<target>_lc<i>/``.
        The folder is created if it does not exist.
    close_figs : bool, optional
        If True, close each figure once it has been drawn/saved so that looping
        over many targets does not keep hundreds of figures in memory. Closed
        figures are not shown inline. Default False keeps the figures open.
    """
    for k in lc_name:
        search_res = search_lc(name=k, mission='TESS', exptime=120, radius=60)
        for i in range(len(search_res)):
            lc = search_res[i].download()
            print('Currently handling ' + k + '_' + str(i))
            # Period at maximum periodogram power, used as the base folding period
            pg_lc = lc.normalize(unit='ppm').to_periodogram()
            max_p = pg_lc.period_at_max_power

            # Raw PDCSAP curve, fold at max_p, and fold at fold_factor * max_p
            ax1 = lc.plot(column='pdcsap_flux', label='PDCSAP Flux', color='teal', normalize=True)
            ax1.set_title('PDCSAP Light Curve of {}_lc{}'.format(k, i))
            ax2 = lc.fold(period=max_p).scatter()
            ax2.set_title('{}_lc{} - Folded at max period'.format(k, i))
            ax3 = lc.fold(period=fold_factor*max_p, wrap_phase=wrap_no).scatter()
            ax3.set_title('{}_lc{} - {}*period_max - wrap_phase = {}'.format(k, i, fold_factor, wrap_no))

            if save_fig:
                lc_path = os.path.join(os.getcwd(), '{}'.format(k), '{}_lc{}'.format(k, i))
                # savefig does not create folders, and this function may run
                # before anything else has created ./<target>/<target>_lc<i>/.
                os.makedirs(lc_path, exist_ok=True)
                outputs = (
                    (ax1, 'pdcsap_{}_lc{}.png'.format(k, i)),
                    (ax2, '{}_lc{}_folded.png'.format(k, i)),
                    (ax3, '{}_lc{}_folded_wrap.png'.format(k, i)),
                )
                for ax, fname in outputs:
                    ax.figure.savefig(os.path.join(lc_path, fname), bbox_inches='tight')

            if close_figs:
                for ax in (ax1, ax2, ax3):
                    plt.close(ax.figure)

In [None]:
# NOTE(review): all_wr_list is only defined in a later cell, so this call fails
# on a fresh "Restart & Run All" unless that cell has been run first.
lc_view(lc_name=all_wr_list, fold_factor=4, wrap_no=0.2, save_fig=True)

In [None]:
def LS_gen(lc_name, mission, exptime, radius, close_figs=False):
    """Generate, report and save the full Lomb-Scargle periodogram of each light curve.

    For every light curve of every target this function:

    * prints the number of periodogram points;
    * prints rough starting estimates for the noise-model fit:
      a0 = mean power of the first 10 bins, a1 = 1 / period at maximum power
      (presumably in 1/day), a3 = mean power of the last 20000 bins (white noise);
    * saves ``LS_full_<target>_lc<i>.png`` into ``./<target>/<target>_lc<i>/``.

    Parameters
    ----------
    lc_name : iterable of str
        Target names (e.g. ``all_wr_list``).
    mission : str
        Mission passed to ``search_lc`` (e.g. "TESS").
    exptime : int
        Exposure time passed to ``search_lc`` (e.g. 120 s).
    radius : float
        Search radius passed to ``search_lc`` (e.g. 60 arcsec).
    close_figs : bool, optional
        If True, close each figure after saving to avoid accumulating figures
        across many targets. Closed figures are not shown inline. Default False.
    """
    for k in lc_name:
        # Forward the caller's search settings instead of fixed values
        search_res = search_lc(name=k, mission=mission, exptime=exptime, radius=radius)

        for i in range(len(search_res)):
            lc = search_res[i].download()
            print('Currently handling ' + k + '_' + str(i))

            # Full Lomb-Scargle periodogram of the ppm-normalised light curve
            pg_lc = lc.normalize(unit='ppm').to_periodogram()
            power = pg_lc.power.value
            dat_len = len(power)                        # needed to size later cuts
            a0_est = np.mean(power[0:10])               # low-frequency level
            a1_est = 1/pg_lc.period_at_max_power.value  # characteristic frequency
            a3_est = np.mean(power[-20000:])            # white-noise floor
            print(dat_len, a0_est, a1_est, a3_est)

            ax = pg_lc.plot(view='frequency', scale='log')
            ax.set_title('Lomb-Scargle periodogram - ' + k +'- LC' + str(i))

            new_subdir = '{}_lc{}'.format(k, i)
            lc_path = os.path.join(os.getcwd(), '{}'.format(k), new_subdir)
            my_LS = 'LS_full_{}_lc{}.png'.format(k, i)

            try:
                os.makedirs(lc_path, exist_ok=True)
            except OSError as error:
                # Best effort: report and skip this save rather than crash in savefig
                print("Directory '%s' can not be created: %s" % (new_subdir, error))
                if close_figs:
                    plt.close(ax.figure)
                continue

            ax.figure.savefig(os.path.join(lc_path, my_LS), bbox_inches='tight')
            if close_figs:
                plt.close(ax.figure)
        

In [None]:
#Function to edit and shorten periodogram for MCMC fittings
def LS_edit(lc_name, fi, ff, step, bin_bool, mission='TESS', exptime=120, radius=60, binsize=15):
    """Compute and save a reduced Lomb-Scargle periodogram on a fixed frequency grid.

    Each figure is saved as ``LS_rd_<target>_lc<i>.png`` in
    ``./<target>/<target>_lc<i>/``; the folder is created if missing.

    Parameters
    ----------
    lc_name : iterable of str
        Target names (e.g. ``all_wr_list``), not a single name.
    fi, ff : float
        First and last frequency of the grid, in lightkurve's default frequency
        unit (presumably 1/day - confirm).
    step : int
        Number of grid points given to ``np.linspace`` (a count, not a spacing).
    bin_bool : bool
        If True, bin the periodogram (mean of ``binsize`` bins).
    mission, exptime, radius : optional
        Search settings passed to ``search_lc``. The defaults reproduce the
        previous fixed values.
    binsize : int, optional
        Bin size used when ``bin_bool`` is True (default 15).

    Notes
    -----
    This function turns pyplot interactive mode off and leaves it off.
    The figures are only written to disk and are closed after saving.
    """
    # The same frequency grid is used for every light curve
    freq_range = np.linspace(fi, ff, step)
    plt.ioff()   # don't display in kernel

    for k in lc_name:
        search_res = search_lc(name=k, mission=mission, exptime=exptime, radius=radius)

        for i in range(len(search_res)):
            lc = search_res[i].download()
            print('Currently handling ' + k + '_' + str(i))

            # Reduced (and optionally binned) LS periodogram
            pg_lc = lc.normalize(unit='ppm').to_periodogram(frequency=freq_range)
            if bin_bool:
                pg_lc = pg_lc.bin(binsize=binsize, method='mean')

            ax = pg_lc.plot(view='frequency', scale='log')
            ax.set_title('Reduced Lomb-Scargle - ' + k +'- LC' + str(i))

            lc_path = os.path.join(os.getcwd(), '{}'.format(k), '{}_lc{}'.format(k, i))
            # The folder is normally made by LS_gen; create it so this runs standalone
            os.makedirs(lc_path, exist_ok=True)
            ax.figure.savefig(os.path.join(lc_path, 'LS_rd_{}_lc{}.png'.format(k, i)), bbox_inches='tight')
            # Not meant for display (ioff), so free it instead of accumulating figures
            plt.close(ax.figure)
    


In [None]:
#Test block: scratch checks of path handling
my_file = 'graph.png'
my_path = os.path.abspath('Auto_LS_WR_fit')
my_dir = 'something'
print(my_path)
print(os.path.abspath(os.path.join(my_path, my_file)))
cur_loc = os.getcwd()
path = os.path.join(cur_loc, my_dir)
# exist_ok keeps the cell re-runnable; without it a second run raises FileExistsError
os.makedirs(path, exist_ok=True)

In [None]:
# Quick check: periodogram length for the first TESS light curve of WR16.
# (Variable renamed from the misleading wr133_lc1 - it holds WR16, index 0.)
wr_search = lk.search_lightcurve("WR16", mission="TESS", exptime=120, radius=60)
wr16_lc0 = wr_search[0].download()
pg = wr16_lc0.normalize(unit='ppm').to_periodogram()
print(len(pg.power.value))

In [None]:
# All WR targets to query in TESS (78 names).
# NOTE(review): an earlier cell (lc_view(all_wr_list, ...)) uses this name
# before this cell defines it, so a fresh "Run All" fails there.
all_wr_list = [
    'WR1', 'WR2', 'WR3', 'WR4', 'WR5', 'WR6', 'WR7', 'WR8', 'WR9', 'WR10',
    'WR11', 'WR12', 'WR14', 'WR15', 'WR16', 'WR17', 'WR18', 'WR21', 'WR22', 'WR23',
    'WR24', 'WR25', 'WR31', 'WR31A', 'WR31B', 'WR40', 'WR42', 'WR43A', 'WR43B', 'WR43C',
    'WR46', 'WR47', 'WR48', 'WR50', 'WR52', 'WR53', 'WR55', 'WR57', 'WR59', 'WR66',
    'WR67', 'WR69', 'WR70', 'WR71', 'WR75', 'WR78', 'WR79', 'WR79A', 'WR79B', 'WR85',
    'WR86', 'WR87', 'WR89', 'WR90', 'WR92', 'WR93', 'WR97', 'WR98', 'WR103', 'WR124',
    'WR127', 'WR128', 'WR133', 'WR134', 'WR135', 'WR136', 'WR137', 'WR138', 'WR139', 'WR140',
    'WR141', 'WR143', 'WR148', 'WR153', 'WR155', 'WR157', 'WR158', 'WR159',
]


In [None]:
#Execute command to generate all LS periodograms
LS_gen(lc_name=all_wr_list, mission="TESS", exptime=120, radius=60)

In [None]:
#Execute command to generate all reduced LS
LS_edit(lc_name=all_wr_list, fi=0.01, ff=100, step=5000, bin_bool=False)

In [None]:
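#Define the model, the Gaussian log-likelihood, a flat prior, and their sum (the log-posterior)
#Parameters: a0 = red-noise amplitude, a1 = characteristic frequency nu_c, a2 = slope gamma, a3 = white-noise level C_w; x = frequency nu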

def model(theta, freq=None):
    """Noise model for the periodogram: ``a0 / (1 + (freq / a1)**a2) + a3``.

    Parameters
    ----------
    theta : sequence of 4 floats
        (a0, a1, a2, a3): red-noise amplitude, characteristic frequency,
        slope, and white-noise level.
    freq : array_like
        Frequencies at which to evaluate the model; lists are accepted.
        Must be supplied. A default bound to a global at definition time would
        raise NameError on a fresh kernel.

    Returns
    -------
    numpy.ndarray (or numpy scalar for scalar ``freq``)

    Raises
    ------
    TypeError
        If ``freq`` is not given.
    """
    if freq is None:
        raise TypeError("model() requires the 'freq' argument")
    a0, a1, a2, a3 = theta
    # Convert so that a plain Python list supports the element-wise arithmetic below
    freq = np.asarray(freq, dtype=float)
    return (a0/(1+(freq/a1)**a2)) + a3

def log_likelihood(theta, x, y, yerr):
    """Gaussian log-likelihood, up to an additive constant.

    Returns ``-0.5 * sum(((y - model(theta, x)) / yerr)**2)``, i.e. -chi^2 / 2.
    """
    scaled_resid = (y - model(theta, x))/yerr
    chi_sq = np.sum(scaled_resid ** 2)
    return -0.5 * chi_sq

#define flat prior
def log_prior(theta):
    """Log of a flat (uniform) prior over an open box.

    Returns 0.0 when 1e3 < a0 < 1e5, 0.1 < a1 < 10, 0 < a2 < 20 and
    10 < a3 < 1e3; otherwise -inf (NaN parameters also give -inf).
    """
    a0, a1, a2, a3 = theta
    inside_box = (
        10**3 < a0 < 10**5
        and 10**-1 < a1 < 10**1
        and 0 < a2 < 20
        and 10**1 < a3 < 10**3
    )
    return 0.0 if inside_box else -np.inf

def log_probability(theta, x, y, yerr):
    """Log-posterior: flat prior plus Gaussian log-likelihood.

    Returns -inf immediately (without evaluating the likelihood) when theta
    falls outside the prior box.
    """
    prior_term = log_prior(theta)
    if np.isfinite(prior_term):
        return prior_term + log_likelihood(theta, x, y, yerr)
    return -np.inf

In [None]:
#Starting guess for a0, a1, a2, a3 (also used as the corner-plot "truths") and
#an assumed error of 5% of the mean amplitude power
# NOTE(review): freq and amp_pow are only built in commented-out code, so they
# must be defined before this cell runs - TODO.
a0_true = 10**4
a1_true = 2
a2_true = 2
a3_true = 80
amp_err = 0.05*np.mean(amp_pow)

#Seeded generator so the walker initialisation is reproducible
SEED = 42
rng = np.random.default_rng(SEED)

#Setting up for ensemble sampler
data = [freq, amp_pow, amp_err]
nwalkers = 150
niter = 3000
initial = np.array([a0_true, a1_true, a2_true, a3_true])
ndim = len(initial)
#Tiny Gaussian ball around the initial guess: one row per walker, shape (nwalkers, ndim)
p0 = initial + 1e-6 * rng.standard_normal((nwalkers, ndim))

In [None]:
def main_MCMC(p0, nwalkers, niter, ndim, log_probability, data, nburn=150):
    """Run an emcee ensemble: a burn-in, a reset, then a production run.

    Parameters
    ----------
    p0 : array_like, shape (nwalkers, ndim)
        Initial walker positions.
    nwalkers, ndim : int
        Number of walkers and parameters.
    niter : int
        Production steps.
    log_probability : callable
        Log-posterior, called as ``log_probability(theta, *data)``.
    data : sequence
        Extra positional arguments for ``log_probability``.
    nburn : int, optional
        Burn-in steps, discarded via ``sampler.reset()`` (default 150).

    Returns
    -------
    sampler : emcee.EnsembleSampler
        Holds only the production chain.
    pos : ndarray
        Final walker positions.
    prob : ndarray
        Final log-probabilities.
    state : random state at the end of the run
    """
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability, args=data)

    print("Running burn-in...")
    # Keep the whole State: production resumes from the burn-in positions, with
    # their cached log-probabilities and the RNG state. Tuple-unpacking a State
    # breaks when blobs are present.
    burn_state = sampler.run_mcmc(p0, nburn)
    sampler.reset()

    print("Running production...")
    final_state = sampler.run_mcmc(burn_state, niter, progress=True)

    return sampler, final_state.coords, final_state.log_prob, final_state.random_state

sampler, pos, prob, state = main_MCMC(p0, nwalkers, niter, ndim, log_probability, data)

In [None]:
# Parameters of the highest-posterior sample in the (post burn-in) chain.
# get_chain/get_log_prob replace the deprecated flatchain/flatlnprobability.
samples = sampler.get_chain(flat=True)
theta_max = samples[np.argmax(sampler.get_log_prob(flat=True))]
theta_max

In [None]:
def plotter(sampler, freq, amp_pow, max_like):
    """Overlay the fitted noise model on the periodogram (log-log axes).

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Sampler holding the chain to draw models from.
    freq, amp_pow : array_like
        Periodogram frequencies and powers that were fitted.
    max_like : bool
        If True, draw only the highest-posterior sample. Otherwise draw 100
        randomly chosen samples (with replacement) to show the posterior spread.
    """
    plt.ion()   # re-enable inline display (LS_edit switches interactive mode off)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.plot(freq, amp_pow, label='LS periodogram')

    samples = sampler.get_chain(flat=True)
    if max_like:
        theta = samples[np.argmax(sampler.get_log_prob(flat=True))]
        ax.plot(freq, model(theta, freq), color='crimson', alpha=0.9, label='Max-posterior model')
    else:
        for theta in samples[np.random.randint(len(samples), size=100)]:
            ax.plot(freq, model(theta, freq), color="r", alpha=0.1)

    ax.set_xlabel('Freq')
    ax.set_ylabel('Amplitude Power')
    ax.legend()
    plt.show()

plotter(sampler, freq, amp_pow, True)

In [None]:
# Check the autocorrelation time to justify discard/thin (uncomment to run):
# tau = sampler.get_autocorr_time()
# print(tau)
# Axis labels for the corner plot; these were previously undefined (NameError)
labels = [r"$a_0$", r"$\nu_c$", r"$\gamma$", r"$C_w$"]
flat_samples = sampler.get_chain(discard=180, thin=45, flat=True)
print(flat_samples.shape)
fig = corner.corner(
    flat_samples, labels=labels, truths=[a0_true, a1_true, a2_true, a3_true]
);

In [None]:
def residual_max(sampler, x, y, xlim=(0.1, 40), ylim=(-100, 500)):
    """Plot the residuals ``y - model(theta, x)`` of the highest-posterior sample.

    Parameters
    ----------
    sampler : emcee.EnsembleSampler
        Sampler whose chain is searched for the best sample.
    x, y : array_like
        Frequencies and periodogram powers that were fitted.
    xlim, ylim : tuple of float, optional
        Axis limits (defaults are the values used before they became parameters).
    """
    samples = sampler.get_chain(flat=True)
    theta = samples[np.argmax(sampler.get_log_prob(flat=True))]
    res = np.asarray(y, dtype=float) - model(theta, x)

    fig, ax = plt.subplots(figsize=(16, 6))
    ax.plot(x, res, color='crimson', alpha=0.8)
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    # Zero line across the whole visible x-range (no hard-coded end point)
    ax.axhline(0, color='k', linestyle='dashed', alpha=0.5)
    ax.set_xlabel('Freq')
    ax.set_ylabel('Amplitude Power')
    plt.show()

plotter(sampler, freq, amp_pow, True)
residual_max(sampler, freq, amp_pow)