# Importing packages

In [None]:
!pip install astropy
!pip install tabulate
!pip install emcee
!pip install corner
!pip install lmfit

In [None]:
# -------------- Import packages --------------
import joblib
joblib.Parallel(n_jobs=12)

# Enable inline plotting in notebook
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from numba import njit
import astropy

from matplotlib.colors import TwoSlopeNorm
import pandas as pd
import pickle
import arx
import time
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import sys
import argparse
import scipy
import PYCCF as myccf
from scipy import stats 
np.set_printoptions(threshold=sys.maxsize)
from matplotlib import gridspec
from scipy import interpolate
from astropy import units as u
import matplotlib.ticker as mtick
import matplotlib.ticker as ticker
import emcee
from IPython.display import display, Math
import scipy as sp
from matplotlib.legend_handler import HandlerTuple
from datetime import datetime
from astropy.io import fits
from astropy.constants import b_wien
import os
import seaborn as sns

# from PyROA, available at https://github.com/Alymantara/PyROA
# Reference: Donnan et al. 2021 (https://ui.adsabs.harvard.edu/abs/2021arXiv210712318D/abstract)
import PyROA
import Utils

import xspec

# from RELAGN, available at https://github.com/scotthgn/RELAGN
# Reference: Hagen & Done (2023b) (https://ui.adsabs.harvard.edu/abs/2023arXiv230401253H/abstract)
from relagn import relagn

In [None]:
plt.rc('xtick', labelsize=15) 
plt.rc('ytick', labelsize=15)
plt.rc('axes', titlesize=20)     # fontsize of the axes title
plt.rc('axes', labelsize=20)
plt.rc('legend', fontsize=20)

# Intercalibration

## Function library
Reference: Vielute R. (in. prep.) Functions are included in their entirety to reflect how they have been used throughout the project

In [None]:

# ===================================== Function Library ============================================
def Clean(Star_File):
    """ Drop duplicate values, existing zp and zp_err columns, and strange errors (= 99.0).
    
    Input:
    Star_File - Original data file (pd dataframe).
    
    Output:
    File2 - Cleaned up data file (pd dataframe).
    """
    
    File2 = Star_File
    File2 = File2.drop_duplicates(subset=['id_apass', 'Filter', 'MJD', 'telid', 'airmass', 'seeing'], keep='first', inplace=False, ignore_index=False)
    File2 = File2.drop(File2[(File2['err_aper'] > 20)].index)
    return(File2)
    
def Locate_Star_ID(Star_File, Star_ID):
    """ Locate star(s) in the file by ID(s).
    
    Input:
    Star_File - Star data (pd dataframe).
    Star_ID - Star ID(s) (float or list).
    
    Output:
    Stars - pd dataframe containing selected stars by ID.
    """
    if type(Star_ID) == list:
        mask = Star_File['id_apass'].isin(Star_ID)
        Stars = Star_File[mask]
    else:
        Stars = Star_File.loc[Star_File['id_apass'] == Star_ID]
    return(Stars)

def Locate_Star_Filter(Star_File, Fltr):
    """ Locate star(s) in the file by filter(s).
    
    Input:
    Star_File - Star data (pd dataframe).
    Fltr - Filter(s) (string or list).
    
    Output:
    Stars - pd dataframe containing selected stars by filter.
    """
    if type(Fltr) == list:
        mask = Star_File['Filter'].isin(Fltr)
        Stars = Star_File[mask]
    else:
        Stars = Star_File.loc[Star_File['Filter'] == Fltr] 
    return(Stars)

def Locate_Star_Scope(Star_File, Telescope):
    """ Locate star(s) in the file by telescope(s).
    
    Input:
    Star_File - Star data (pd dataframe).
    Telescope - Telescope(s) (string or list).
    
    Output:
    Stars - pd dataframe containing selected stars by telescope.
    """
    if type(Telescope) == list:
        mask = Star_File['telid'].isin(Telescope)
        Stars = Star_File[mask]
    else:
        Stars = Star_File.loc[Star_File['telid'] == Telescope]
    return(Stars)

def Locate_Star_Epoch(Star_File, Epoch):
    """ Locate star(s) in the file by Epoch(s).
    
    Input:
    Star_File - Star data (pd dataframe).
    Epoch - MJD(s) (string or list).
    
    Output:
    Stars - pd dataframe containing selected stars by epoch.
    """
    if type(Epoch) == list:
        mask = Star_File['MJD'].isin(Epoch)
        Stars = Star_File[mask]
    else:
        Stars = Star_File.loc[Star_File['MJD'] == Epoch]
    return(Stars)

def Generic_Optimal(Data, Errors):
    """ Compute inverse variance weighted average.
    
    Input:
    Data - Data to average (list/array).
    Errors - Errors of data to average (list/array).
    
    Output:
    avg - Optimal average (float).
    err - standard deviation of average (float).
    """
    
    #errors handled by returning nan values
    #ensure data are in arrays
    var = np.array(Errors)**2
    Data = np.array(Data)
    
    if len(var) != 0:
        try:
            w = 1/var
            avg = np.nansum(w*Data) / np.nansum(w)
            var = 1/np.nansum(w)
        except:
            avg = np.nan
            var = np.nan
    else:
        avg = np.nan
        var = np.nan
    
    if var**0.5 == np.inf:
        var == np.nan
        avg == np.nan
    
    err = var**0.5
    return(avg, err)

def Brightest(Star_File, Filter, AGN_ID = AGN_ID):
    """ IDs of the brghtest stars in the data for a specific filter, identified using optimal average. 
    
    Input:
    Star_File - Star data (pd dataframe).
    Filter - Filter (string).
    AGN_ID - ID of AGN (int).
    
    Output:
    Brightest_IDs - IDs of stars sorted from brightest to dimmest (list).
    """
    
    #Select specified filter data
    Star_Data0 = Locate_Star_Filter(Star_File, Filter)
    
    #Make sure AGN isn't included
    Star_IDs = [k for k in pd.unique(Star_Data0['id_apass']) if k!= AGN_ID]

    #Compute optimal average of instrumental magnitudes
    MAGS = []
    IDS = []
    for ID in Star_IDs:
        Star_Data = Locate_Star_ID(Star_Data0, ID)
        mag = Star_Data['mag_aper'].values
        mag_err = Star_Data['err_aper'].values
        mag_mean = Generic_Optimal(mag, mag_err)[0]
        MAGS.append(mag_mean)
        IDS.append(ID)
    
    #Sort from brightest to dimmest and return IDs
    MAGS_ref = MAGS
    MAGS = sorted(MAGS)
    Brightest_indices = [MAGS_ref.index(i) for i in MAGS]
    Brightest_IDs = np.array(IDS)[np.array(Brightest_indices)]
    return Brightest_IDs

def root_mean_squared_deviation(data, mean):
    """Compute rms of dataset.
    Input:
    data - dataset to compute rms for (array/list).
    mean - mean of dataset (float).
    
    Output:
    rms - root mean squared deviation (float).
    """
    
    rms = (sum((data-mean)**2) / len(data))**0.5
    return(rms)

def Brightest_Reduced(Star_File, Filter, AGN_ID = AGN_ID, frac = 0.5):
    """ IDs of the brightest stars in the data for a specific filter, identified using optimal average. 
        Only include stars that have more than specified number of data points per star.
    
    Input:
    Star_File - Star data (pd dataframe).
    Filter - Filter (string).
    AGN_ID - ID of AGN (int).
    frac - (Optional) Fraction of max number of epochs to keep (float).
    
    Output:
    Brightest_IDs - IDs of stars sorted from brightest to dimmest (list).
    """
    
    #Select specified filter data
    Star_Data0 = Locate_Star_Filter(Star_File, Filter)
    
    #Make sure AGN isn't included
    Star_IDs = [k for k in pd.unique(Star_Data0['id_apass']) if k!= AGN_ID]
    
    Max_length = 0.0  
    #Get max number of epochs amongst all stars
    for ID in Star_IDs:
        Star_Data = Locate_Star_ID(Star_Data0, ID)
        if len(Star_Data['MJD'].values) > Max_length:
            Max_length = len(Star_Data['MJD'].values)
    
    #Select all stars with more than specified number of datapoints
    keep_dp = frac * Max_length
    Reduced_IDs = []
    for ID in Star_IDs:     
        Star_Data = Locate_Star_ID(Star_Data0, ID)
        if len(Star_Data['MJD'].values) > keep_dp:
            Reduced_IDs.append(ID)
    
    #Compute optimal average of instrumental magnitudes
    MAGS = []
    IDS = []
    for ID in Reduced_IDs:
        Star_Data = Locate_Star_ID(Star_Data0, ID)
        mag = Star_Data['mag_aper'].values
        mag_err = Star_Data['err_aper'].values
        mag_mean = Generic_Optimal(mag, mag_err)[0]
        MAGS.append(mag_mean)
        IDS.append(ID)
    
    #Sort from brightest to dimmest and return IDs
    MAGS_ref = MAGS
    MAGS = sorted(MAGS)
    Brightest_indices = [MAGS_ref.index(i) for i in MAGS]
    Brightest_IDs = list(np.array(IDS)[np.array(Brightest_indices)])
    return Brightest_IDs

def Epoch_Dist(Star_File, Filter, AGN_ID = AGN_ID, frac = 0.0):
    """ Display distribution of stars by number of epochs per star.
        Include how many stars are available with at least a certain fraction of the max number of epochs.
        Also show the instrumental magnitude of stars as a function of the number of epochs.
        
    Input:
    Star_File - Star data (pd dataframe).
    Filter - Filter (string).
    frac - (optional) print number of stars available with at least this fraction of the max number of epochs.
    """
    
    #Select specified filter data
    Star_Data0 = Locate_Star_Filter(Star_File, Filter)
    
    #Make sure AGN isn't included
    Star_IDs = [k for k in pd.unique(Star_Data0['id_apass']) if k!= AGN_ID]
    #Count number of epochs for each star and mean instrumental mags
    Lengths = []
    Mags = []
    for ID in Star_IDs:
        Star_Data = Locate_Star_ID(Star_Data0, ID)
        Lengths.append(len(Star_Data['MJD'].values))
        Mags.append(Generic_Optimal(Star_Data.mag_aper.values, Star_Data.err_aper.values)[0])
    Lengths = np.array(Lengths)
    
    # --------------- Plot data ----------------
    fig, (ax, axb) = plt.subplots(1,2, gridspec_kw={'width_ratios': [3, 2]}, figsize = (16, 6))
    fig.subplots_adjust(wspace = 0.25)
    nbins = 50
    hist0 = ax.hist(Lengths, bins = nbins, weights=np.ones(len(Lengths)) / len(Lengths))
    ax.cla()
    hist = ax.hist(Lengths, bins = nbins, color = 'gray', alpha = 0.7, zorder = 10, edgecolor = None, weights=np.ones(len(Lengths)) / max(hist0[0]) / len(Lengths))

    #CDF
    X0 = sorted(list(np.arange(min(Lengths), max(Lengths) + (max(Lengths) - min(Lengths)) / nbins, (max(Lengths) - min(Lengths)) / nbins))*2)[:-1]
    N = [0.]
    for i in hist0[0]:
        N.append(i+N[-1])
    N = sorted(N*2)[1:]
    X = sorted(X0)
    ax.plot(X, N, color = 'black', zorder = 20, lw = 1.0)
    
    #Number of Stars Available
    labs = []
    X2 = []
    N2 = []
    for j in np.arange(1.0, -0.1, -0.01):
        j = round(j, 3)
        X2.append(j*max(Lengths))
        N2.append(len([k for k in Lengths if k > j*max(Lengths)]) / len(Lengths))
        labs.append(str(len([k for k in Lengths if k > j*max(Lengths)])))
    N2 = sorted(N2*2)[1:]
    X2 = sorted(X2*2)[1:] ; X2 = sorted(X2, reverse = True)
    ax.plot(X2, N2, color = 'Teal', zorder = 20, lw = 1.0)
    
    # ------- Formatting -------
    ax.set_title(str(Filter) + ' Filter')    
    ax.set_xlabel('Number of Epochs')
    ax.set_ylabel('# of Stars Available with at least M datapoints \n M = Frac x max # of Epochs', color = 'Teal')
    ax.set_xlim(0, max(Lengths))
    ax.set_xticks(np.arange(0, max(Lengths), 100))
    ax.set_ylim(0, 1.01)
    ax.set_yticks(np.arange(0.0, 1.1, 0.1))
    ax.set_yticklabels([round(i) for i in np.arange(0, len(Lengths)+len(Lengths)/11, 0.1*len(Lengths))], color = 'Teal')
    ax.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    ax.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    ax.minorticks_on()
    
    ax2 = ax.twinx()
    ax2.plot(X, N, alpha = 0)
    ax2.set_ylim(0, 1.01)
    ax2.set_yticks(np.arange(0, 1.1, 0.1))
    ax2.set_yticklabels([round(j, 2) for j in np.arange(0, 1.1, 0.1)], color = 'black')
    ax2.set_ylabel('CDF', color = 'black')
    ax2.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    ax2.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    ax2.minorticks_on()
    
    ax3 = ax.twiny()
    ax3.vlines(np.arange(0,1.05,0.1), 0, 1,color = 'teal', lw = 0.8, ls = '--', alpha = 0.2, zorder = 0)
    ax3.vlines(np.arange(0,1.05,0.05), 0, 1,color = 'teal', lw = 0.5, ls = '--', alpha = 0.2, zorder = 0)
    ax3.hlines(np.arange(0,1.05,0.1), 0, 1,color = 'teal', lw = 0.8, ls = '--', alpha = 0.2, zorder = 0)
    ax3.hlines(np.arange(0,1.05,0.05), 0, 1,color = 'teal', lw = 0.5, ls = '--', alpha = 0.2, zorder = 0)
    ax3.set_xticks(np.arange(0, 1.1, 0.1))
    ax3.set_xticklabels([round(j, 2) for j in np.arange(0, 1.1, 0.1)], color = 'Teal')
    ax3.plot(X, N, alpha = 0)
    ax3.set_xlim(0, 1.01)
    ax3.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    ax3.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    ax3.minorticks_on()
    ax3.set_xlabel('Frac', color = 'Teal')
        
    axb.scatter(Lengths, Mags, color = 'orange')
    axb.set_xlabel('Number of Epochs')
    axb.set_ylabel('Mean Instrumental Mag')
    axb.invert_yaxis()
    axb.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    axb.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    axb.minorticks_on()
    
    if frac == 0.0:
        print('============================== ',str(Filter), 'Filter  ==============================')
        print('Total number of stars available: ', len(Lengths))
        print('Max number of epochs amongst all stars:', max(Lengths))
    else:
        print('============================== ',str(Filter), 'Filter  ==============================')
        print('Total number of stars available: ', len(Lengths))
        print('Max number of epochs amongst all stars:', max(Lengths))
        print('Number of stars available with more than ' + str(frac)+' of the max number of epochs: '+str(len([k for k in Lengths if k > frac*max(Lengths)])) )

def Detect_Var(Star_File, Filter, High = [], Label = False, Log = False, HighLabel = True, TEL = TEL, AGN_ID = AGN_ID):
    """Display rms vs mag plot to identify variable stars. *work in progress - do a fitting/binning*

    Input:
    Star_File - calibrated star dataframe (pd dataframe).
    Filter - Filter (string).
    High - (Optional) any stars you want to highlight better on the plot for clarity (list/array)
    """
    
    #Set up figure
    fig, ax = plt.subplots(1,1)
    
    #Put data on log plot if wanted
    if Log == True:
        ax.set_yscale('log')
    
    ax.set_title(Filter + ' Filter')
    #All stars in file
    Star_IDs = pd.unique(Star_File['id_apass'])
    
    for ID in Star_IDs:
        #Select data for particular star
        dat = Locate_Star_ID(Star_File, ID)
        mags = dat['m_star'].values[0]
        mags_err = dat['m_star_err'].values[0]
        
        #If you wanted to plot the extra variance
        #rms_star = dat['rms_star'].values[0]
        #rms_star_err = dat['rms_star_err'].values[0]
        
        #Compute rms
        rms_star = root_mean_squared_deviation(dat['mag_aper'], dat['m_star'])
        
        if ID in High:
            ax.scatter(mags, rms_star, color = 'deeppink', s = 25, marker = 'x', zorder = 10)
        else:
            ax.scatter(mags, rms_star, color = 'teal', s = 15)
        
        #Add star IDs nect to data points if wanted
        if Label == True:
            
            if Log == True:
                if ID in High:
                    t = ax.text(mags, rms_star, str(ID), rotation = 0, color = 'deeppink', alpha = 1.0, size = 10, zorder = 10, weight = 'bold')
                else:
                    ax.text(mags, rms_star, str(ID), rotation = 90, color = 'black', alpha = 0.5, size = 8)
            else:
                if ID in High:
                    t = ax.text(mags - 0.05, rms_star + 0.003, str(ID), rotation = 0, color = 'deeppink', alpha = 1.0, size = 10, zorder = 10, weight = 'bold')                        
                else:
                    ax.text(mags - 0.05, rms_star + 0.003, str(ID), rotation = 90, color = 'black', alpha = 0.5, size = 8)
        elif HighLabel == True:
            if Log == True:
                if ID in High:
                    t = ax.text(mags, rms_star, str(ID), rotation = 0, color = 'deeppink', alpha = 1.0, size = 10, zorder = 10, weight = 'bold')
            else:
                if ID in High:
                    t = ax.text(mags - 0.05, rms_star + 0.003, str(ID), rotation = 0, color = 'deeppink', alpha = 1.0, size = 10, zorder = 10, weight = 'bold')                        

    ax.set_xlabel('Mean Instrumental Magnitude')
    ax.set_ylabel('RMS Star')
    ax.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    ax.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    ax.minorticks_on()
    
def Plot_LC(Star_Data, Filter, TEL = TEL, err_th = 0.05, Stars = []):
    """Plot calibrated star lightcurves and total error distribution.
       Remove outliers based on a double outlier detection method and inspect cleaned lightcurves.
    
    Input:
    Star_Data - Calibrated star dataframe (pd dataframe).
    Filter - Filter (string).
    err_th - (Optional) Error threshold for clipping outliers.
    Stars - (Optional) If you don't wish to plot aLL the lightcurves for all the stars in the file, input IDs of stars you do want to see (list/array).
    """
    COLORS = ['navy', 'green', 'blue', 'orange', 'purple', 'yellow', 'skyblue', 'violet', 'darkgreen', 'maroon']
    
    if len(Stars) > 0:
        #Selecting only inputed stars
        Star_IDs = Stars
    else:
        #All stars in dataframe
        Star_IDs = pd.unique(Star_Data['id_apass'])
    
    for ID in Star_IDs:
        #set up figure
        fig0, ax0 = plt.subplots(5,2, sharex = False, figsize = (13, 15))
        fig0.subplots_adjust(wspace=0.1, hspace = 0.4)
        r = 0
        c = 0
        
        #older MAD method
        #scope_mad = []
        
        for scope in TEL:
            #Select data for particular telescope
            dat = Locate_Star_ID(Locate_Star_Scope(Star_Data, scope), ID)
            err = dat.err_tot
            
            if len(dat) > 0:
                nbins = int(0.8*len(err))
                if nbins == 0:
                    nbins = 1
                hist = ax0[r, c].hist(err, bins = nbins, alpha = 0.7, density = False, color = COLORS[TEL.index(scope)], edgecolor = 'black')
                ax0[r, c].vlines(err_th, 0, max(hist[0]), lw = 1.2, ls = '--', color = 'red', label = 'Outlier Threshold')
                
                #older MAD method
                #MAD = np.sum(abs(err - np.median(err))) / len(err)
                #ax0[r, c].axvspan(np.median(err) - MAD, np.median(err) + MAD, alpha=0.1, color='red',label = 'MAD')
                #ax0[r, c].axvspan(np.median(err) - 2*MAD, np.median(err) + 2*MAD, alpha=0.1, color='orange',label = '2*MAD')
                
                ax0[r, c].legend()
                ax0[r, c].tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
                ax0[r, c].tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
                ax0[r, c].minorticks_on()
                ax0[r, c].set_xlabel('Total error')
            ax0[r, c].set_title('Star ' + str(ID) + '  |  '+ str(scope) +'  |  No. of datapoints: ' + str(len(dat)))

            if c < 1:
                c = c + 1
            else:
                r = r + 1
                c = 0
            #older mad method
            #scope_mad.append(MAD)

        #Set up figure
        fig, ax4 = plt.subplots(2,1, figsize = (10, 6), sharex = True)
        fig.subplots_adjust(hspace = 0)
        ax4[0].set_ylabel('Mag')
        ax4[1].set_ylabel('Mag')
        ax4[1].set_xlabel('MJD')

        tot_dat = 0.
        clipped = 0.
        for scope in TEL:
            #Select data based on particular telescope
            dat = Locate_Star_Scope(Locate_Star_Filter(Locate_Star_ID(Star_Data, ID), Filter), scope)
            if len(dat) > 0:
                #older mad method
                #con1 = (dat['err_tot'] > np.median(dat['err_tot']) + scope_mad[TEL.index(scope)])

                #First outlier detection method based on specified threshold
                con1 = (dat['err_tot'] > err_th)
                #Secondary outlier detection method based on deviation from mean
                con2 = (abs(dat['m_star'] - dat['mag_aper']) > 3*dat['err_tot'])
                outliers = np.where(np.logical_or(con1,con2))[0]
                not_outliers = [j for j in np.arange(0, len(dat), 1) if j not in outliers]

                dat2a = dat.iloc[outliers]
                dat2b = dat.iloc[not_outliers]
                tot_dat = tot_dat + len(dat)
                clipped = clipped + len(outliers)

                ax4[0].scatter(dat2a['MJD'].values, dat2a['mag_aper'].values, color = 'red', s = 30, marker = 'x', zorder = 10, alpha = 1.0, linewidths=1.0)
                label = scope

                Data = Locate_Star_Scope(dat, scope)
                ax4[1].scatter(dat2b['MJD'].values, dat2b['mag_aper'].values, s = 20, zorder = 10, label = label, color = COLORS[TEL.index(scope)])
                ax4[1].errorbar(dat2b['MJD'].values, dat2b['mag_aper'].values, yerr = dat2b['err_tot'].values,ls = 'none', color = COLORS[TEL.index(scope)], lw = 0.5)
                ax4[0].errorbar(dat['MJD'].values, dat['mag_aper'].values, yerr = (dat['err_aper'].values**2 + dat['rms_sc'].values**2 + dat['rms_t'].values**2 + dat['rms_star'].values**2)**0.5, ls = 'none', color = COLORS[TEL.index(scope)], lw = 0.5)
                ax4[0].scatter(dat['MJD'].values, dat['mag_aper'].values, s = 20, label = label, color = COLORS[TEL.index(scope)])
        
        Data = Locate_Star_ID(Star_Data, ID)
        ax4[0].hlines(Data['m_star'].values[0], min(Data['MJD'].values), max(Data['MJD'].values), lw = 1, ls = '--', label = 'Star: ' + str(ID))
        ax4[0].legend(loc='center right', bbox_to_anchor=(1.15, 0.5), fontsize = 'x-small')
        ax4[0].invert_yaxis()
        ax4[0].tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
        ax4[0].tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
        ax4[0].minorticks_on()
        ax4[0].set_title('Star ' + str(ID) + '  |  Datapoints Clipped: '+str(int(clipped))+' out of '+str(int(tot_dat)))
        ax4[1].hlines(Data['m_star'].values[0], min(Data['MJD'].values), max(Data['MJD'].values), lw = 1, ls = '--', label = 'Star: ' + str(ID))
        ax4[1].legend(loc='center right', bbox_to_anchor=(1.15, 0.5), fontsize = 'x-small')
        ax4[1].invert_yaxis()
        ax4[1].tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
        ax4[1].tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
        ax4[1].minorticks_on()
        
def AGN_LC(Original_Star_File, Star_Data, Filter, err_th = 0.05, Plot = True, Rem_out = True, AGN_ID = AGN_ID, TEL = TEL, zp = [0.,0.]):
    """ Return Dataframe with calibrated AGN Lightcurve. Choose whether to leave in outliers.
        Can also plot the AGN lightcurve.
    
    Input:
    Original_Star_File - Original lco file (pd dataframe).
    Star_Data - Calibrated star dataframe (pd dataframe).
    Filter - Filter (string).
    err_th - (Optional) Error threshold for clipping outliers.
    Plot - (Optional) Choose whether to display the error distributions and lightcurves (Bool).
    Rem_out = (Optional) Choose whether to remove outliers from returned AGN dataframe (Bool). 
    
    Output:
    AGN_DF - dataframe containing calibrated AGN lightcurve.
   """
    
    #Set up dataframe
    columns = ['Filter','telid', 'MJD', 'mag', 'err', 'err_sys']
    AGN_DF = pd.DataFrame(columns=columns)
    
    AGN_Data = Locate_Star_ID(Locate_Star_Filter(Original_Star_File, Filter), AGN_ID)
    for scope in TEL:
        
        #Select data based on telescope
        AGN_Data2 = Locate_Star_Scope(AGN_Data, scope)

        for E in pd.unique(AGN_Data2['MJD']):
            
            #Select data based on Epoch
            AGN_Data3 = Locate_Star_Epoch(AGN_Data2, E)
            Star_Data2 = Locate_Star_Epoch(Locate_Star_Filter(Locate_Star_Scope(Star_Data, scope), Filter), E).head(1)
            
            #Assuming the AGN has a datapoint for that epoch...
            if len(Star_Data2) != 0:
                
                #Apply correction parameters
                AGN_mag = AGN_Data3['mag_aper'].values - Star_Data2['DMAGT'].values - Star_Data2['DMAGS'].values + zp[0]
                AGN_mag_err = (AGN_Data3['err_aper'].values[0]**2 + Star_Data2['rms_sc'].values[0]**2 + Star_Data2['rms_t'].values[0]**2)**0.5
                
                temp = [{'Filter': Filter, 'telid': scope, 'MJD': E, 'mag': AGN_mag[0], 'err': AGN_mag_err, 'err_sys': AGN_mag_err*0.0+zp[1]}]
                temp_DF = pd.DataFrame(temp)
                
                AGN_DF = pd.concat([AGN_DF, temp_DF], ignore_index=True)
                
                # AGN_DF = AGN_DF.append({'Filter': Filter, 'telid': scope, 'MJD': E, 'mag': AGN_mag[0], 'err': AGN_mag_err, 'err_sys': AGN_mag_err*0.0+zp[1]}, ignore_index = True)
    
    if Plot == True:
        print('======================================================== ' + Filter + ' Filter' + ' ========================================================')
        
        COLORS = ['navy', 'green', 'blue', 'orange', 'purple', 'yellow', 'skyblue', 'violet', 'darkgreen', 'maroon']
        
        #Set up figures
        fig0, ax0 = plt.subplots(5,2, sharex = False, figsize = (13, 15))
        fig0.subplots_adjust(wspace=0.1, hspace = 0.4)
        
        fig, ax = plt.subplots(1,1,figsize=(13, 6), sharex = False)
        ax.set_title('Uncalibrated AGN Lightcurve')
        ax.set_xlabel('MJD')
        ax.set_ylabel('Mag')
        ax.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
        ax.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
        ax.minorticks_on()
        ax.invert_yaxis()

        fig2, ax2 = plt.subplots(1,1,figsize=(13, 6))
        fig3, ax3 = plt.subplots(1,1,figsize=(13, 6))
        ax2.set_title('Calibrated AGN Lightcurve with outliers')
        ax2.set_ylabel('Mag')
        ax2.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
        ax2.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
        ax2.minorticks_on()
        ax2.invert_yaxis()
        
        clipped = len(AGN_DF.loc[AGN_DF['err'] > err_th])
        ax3.set_title('Calibrated AGN Lightcurve without outliers  |  '+ 'Datapoints Clipped: ' +str(int(clipped))+' out of '+str(len(AGN_DF))+'  |  Outlier error threshold: '+ str(round(100*err_th, 1)) + '%')
        ax3.set_ylabel('Mag')
        ax3.set_xlabel('MJD')
        ax3.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
        ax3.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
        ax3.minorticks_on()
        ax3.invert_yaxis()
        
        
        r = 0
        c = 0
        #Plot error distributions
        for scope in TEL:
            dat = Locate_Star_Scope(AGN_DF, scope)
            
            if len(dat) > 0:
                nbins = int(len(dat) / 2)
                if nbins == 0:
                    nbins = 1
                hist = ax0[r, c].hist(dat['err'].values, bins = nbins, alpha = 0.7, density = False, color = COLORS[TEL.index(scope)], edgecolor = 'black')
                ax0[r, c].vlines(err_th, 0, max(hist[0]), lw = 1.2, ls = '--', color = 'red', label = 'Outlier Threshold: '+ str(err_th))

                #Old mad selection
                #ax0[r, c].vlines(np.median(dat['err'].values), 0, max(hist[0]), lw = 1, ls = '--', color = 'maroon', label = 'Median: '+ str(round(np.median(dat['err'].values), 6)))
                #MAD = np.sum(abs(errors[TEL.index(scope)] - np.median(errors[TEL.index(scope)]))) / len(errors[TEL.index(scope)])
                #ax0[r, c].axvspan(np.median(dat['err'].values) - MAD, np.median(dat['err'].values) + MAD, alpha=0.1, color='red',label = 'MAD')
                #ax0[r, c].axvspan(np.median(dat['err'].values) - 2*MAD, np.median(dat['err'].values) + 2*MAD, alpha=0.1, color='orange',label = '2*MAD')

                ax0[r,c].legend()
                ax0[r,c].tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
                ax0[r,c].tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
                ax0[r,c].minorticks_on()
                ax0[r,c].set_title(str(scope) + '  |  No. of datapoints: '+ str(len(dat['err'].values)))
                ax0[r,c].set_xlabel('Total error')
    
            if c < 1:
                c = c + 1
            else:
                r = r + 1
                c = 0
            
            #Plot uncalibrated data
            ax.scatter(Locate_Star_Scope(AGN_Data, scope)['MJD'], Locate_Star_Scope(AGN_Data, scope)['mag_aper'], s = 5, label = scope, color = COLORS[TEL.index(scope)])
            
            #Plot calibrated data with errors
            ax2.scatter(dat['MJD'], dat['mag'], s = 5, color = COLORS[TEL.index(scope)], label = scope)
            ax2.errorbar(dat['MJD'], dat['mag'], yerr = dat['err'], lw = 0.5, color = COLORS[TEL.index(scope)], ls = 'none')
            
            #Plot calibrated data without errors
            dat2 = dat.loc[dat['err'] < err_th]
            ax3.scatter(dat2['MJD'], dat2['mag'], s = 5, color = COLORS[TEL.index(scope)], label = scope)
            ax3.errorbar(dat2['MJD'], dat2['mag'], yerr = dat2['err'], lw = 0.5, color = COLORS[TEL.index(scope)], ls = 'none')
            
        ax.legend()
        ax2.legend()
        ax3.legend()
    
    if Rem_out == True:
        print('Returning dataframe with outliers taken out')
        AGN_DF = AGN_DF.loc[AGN_DF['err'] < err_th]
        return AGN_DF
    else:
        print('Returning dataframe with outliers kept in')
        return AGN_DF
    
def AGN_LC2(Original_Star_File, Star_Data, Filter, err_th = 0.05, Plot = True, Rem_out = True, AGN_ID = AGN_ID, TEL = TEL, zp = [0.,0.]):
    """ Return Dataframe with calibrated AGN Lightcurve. Choose whether to leave in outliers.
        Can also plot the AGN lightcurve with shared axes.
    
    Input:
    Original_Star_File - Original lco file (pd dataframe).
    Star_Data - Calibrated star dataframe (pd dataframe).
    Filter - Filter (string).
    err_th - (Optional) Error threshold for clipping outliers.
    Plot - (Optional) Choose whether to display the error distributions and lightcurves (Bool).
    Rem_out = (Optional) Choose whether to remove outliers from returned AGN dataframe (Bool). 
    
    Output:
    AGN_DF - dataframe containing calibrated AGN lightcurve.
   """
    
    #Set up dataframe
    columns = ['Filter','telid', 'MJD', 'mag', 'err', 'err_sys']
    AGN_DF = pd.DataFrame(columns=columns)
    
    AGN_Data = Locate_Star_ID(Locate_Star_Filter(Original_Star_File, Filter), AGN_ID)
    for scope in TEL:
        
        #Select data based on telescope
        AGN_Data2 = Locate_Star_Scope(AGN_Data, scope)

        for E in pd.unique(AGN_Data2['MJD']):
            
            #Select data based on Epoch
            AGN_Data3 = Locate_Star_Epoch(AGN_Data2, E)
            Star_Data2 = Locate_Star_Epoch(Locate_Star_Filter(Locate_Star_Scope(Star_Data, scope), Filter), E).head(1)
            
            #Assuming the AGN has a datapoint for that epoch...
            if len(Star_Data2) != 0:
                
                #Apply correction parameters
                AGN_mag = AGN_Data3['mag_aper'].values - Star_Data2['DMAGT'].values - Star_Data2['DMAGS'].values + zp[0]
                AGN_mag_err = (AGN_Data3['err_aper'].values[0]**2 + Star_Data2['rms_sc'].values[0]**2 + Star_Data2['rms_t'].values[0]**2)**0.5
                
                temp = [{'Filter': Filter, 'telid': scope, 'MJD': E, 'mag': AGN_mag[0], 'err': AGN_mag_err, 'err_sys': AGN_mag_err*0.0+zp[1]}]
                temp_DF = pd.DataFrame(temp)
                
                AGN_DF = pd.concat([AGN_DF, temp_DF], ignore_index=True)
                
                # AGN_DF = AGN_DF.append({'Filter': Filter, 'telid': scope, 'MJD': E, 'mag': AGN_mag[0], 'err': AGN_mag_err, 'err_sys': AGN_mag_err*0.0+zp[1]}, ignore_index = True)
    
    if Plot == True:
        print('======================================================== ' + Filter + ' Filter' + ' ========================================================')
        
        COLORS = ['navy', 'green', 'blue', 'orange', 'purple', 'yellow', 'skyblue', 'violet', 'darkgreen', 'maroon']
        
        #Set up figures        
        fig, ax = plt.subplots(2, figsize=(13,12), sharex=True)
        plt.subplots_adjust(wspace=0, hspace=0)
        ax[0].tick_params(axis = 'y', which = 'major', direction = 'out', length = 6)
        ax[0].tick_params(axis = 'y', which = 'minor', direction = 'out', length = 4)
        ax[1].tick_params(axis = 'y', which = 'major', direction = 'out', length = 6)
        ax[1].tick_params(axis = 'y', which = 'minor', direction = 'out', length = 4)
        ax[1].tick_params(axis = 'x', which = 'major', direction = 'out', length = 6)
        ax[1].tick_params(axis = 'x', which = 'minor', direction = 'out', length = 4)
        ax[0].set_ylabel('Uncalibrated Mag')
        ax[0].set_xlabel('MJD')
        ax[1].set_ylabel('Calibrated Mag')
        ax[1].set_xlabel('MJD')        
        ax[0].minorticks_on()
        ax[0].invert_yaxis()
        ax[1].minorticks_on()
        ax[1].invert_yaxis()
        
        
        r = 0
        c = 0
        #Plot error distributions
        for scope in TEL:
            dat = Locate_Star_Scope(AGN_DF, scope)
            
            
    
            if c < 1:
                c = c + 1
            else:
                r = r + 1
                c = 0
            
            #Plot uncalibrated data
            ax[0].scatter(Locate_Star_Scope(AGN_Data, scope)['MJD'], Locate_Star_Scope(AGN_Data, scope)['mag_aper'], s = 5, label = scope, color = COLORS[TEL.index(scope)])
            
            #Plot calibrated data without errors
            dat2 = dat.loc[dat['err'] < err_th]
            ax[1].scatter(dat2['MJD'], dat2['mag'], s = 5, color = COLORS[TEL.index(scope)], label = scope)
            ax[1].errorbar(dat2['MJD'], dat2['mag'], yerr = dat2['err'], lw = 0.5, color = COLORS[TEL.index(scope)], ls = 'none')
            
        ax[0].legend(ncol=3, bbox_to_anchor=(0.8, -1.2), fontsize=20)
        
    plt.savefig(fr"{Filter}_agnlc", dpi=300, bbox_inches='tight')
    
    if Rem_out == True:
        print('Returning dataframe with outliers taken out')
        AGN_DF = AGN_DF.loc[AGN_DF['err'] < err_th]
        return AGN_DF
    else:
        print('Returning dataframe with outliers kept in')
        return AGN_DF
    
def Plot_by_year(AGN_dataframe, year):
    """ Plot particular year of AGN data for closer visual inspection.
        *Note: MJD boundaries may need to be adjusted here for individual AGNs.
    
    Input:
    AGN_dataframe - dataframe containing calibrated AGN data from AGN_LC function (pd dataframe).
    year - year of data to plot (int)
    """
    
    COLORS = ['navy', 'orange', 'yellow', 'green', 'blue', 'indigo', 'violet', 'pink', 'brown', 'gray']
    
    #Ref dates may need to be changed depending on specific AGN monitoring
    if year == 1:
        ref_date1 = 0
        ref_date2 = 57754
    elif year == 2:
        ref_date1 = 57754
        ref_date2 = 58119
    elif year == 3:
        ref_date1 = 58119
        ref_date2 = 10**(10)

    fig, ax = plt.subplots(len(FILTERS),1, figsize = (15, 10), sharex = True)
    fig.subplots_adjust(wspace=0.1, hspace = 0.0)
    
    ax[0].set_title('Year ' + str(year), size = 15)
    ax[-1].set_xlabel('MJD')
    for Filter in FILTERS:
        dat0 = Locate_Star_Filter(AGN_dataframe, Filter)
        dat = dat0.loc[(dat0['MJD'] > ref_date1) & (dat0['MJD'] < ref_date2)]
    
        ax[FILTERS.index(Filter)].scatter(dat.MJD, dat.mag, s = 10, color = 'Teal', zorder = 0, label = Filter)
        ax[FILTERS.index(Filter)].errorbar(dat.MJD, dat.mag, yerr = dat.err, ls = 'none', color = 'Teal', zorder = 0)
        ax[FILTERS.index(Filter)].invert_yaxis()
        ax[FILTERS.index(Filter)].legend()
        ax[FILTERS.index(Filter)].set_ylim(min(dat.mag) - 0.05, max(dat.mag) + 0.05)
        ax[FILTERS.index(Filter)].vlines(np.arange(min(dat['MJD'].values), max(dat['MJD'].values) + 10, 10), min(dat['mag'].values) - 0.3, max(dat['mag'].values) + 0.3, color = 'black', ls = '--', lw = 0.5, alpha = 0.8)
        ax[FILTERS.index(Filter)].tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
        ax[FILTERS.index(Filter)].tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
        ax[FILTERS.index(Filter)].minorticks_on()
        ax[FILTERS.index(Filter)].set_ylabel('mag')
        ax[FILTERS.index(Filter)].invert_yaxis()

def Check_RMS_MAG(Star_File, Filter = FILTERS[0]):
    """Compute the mean magnitude & rms of initial star data, calibrated using zeropoints, and plot rms vs mag to 
       check AGN ID.
    Inputs:
    Star_File - Original file with zeropoints (pd dataframe).
    Filter - (Optional) Filter (str).
    """
    
    #Set up figure
    fig, ax = plt.subplots(1,1, figsize = (5, 5))
    #Select data based on filter
    dat = Locate_Star_Filter(Star_File, Filter)
    Star_IDs = pd.unique(dat['id_apass'])
    
    for ID in Star_IDs:
        #Select data based on star ID
        dat2 = Locate_Star_ID(dat, ID)
        mag = dat2.mag_aper.values + dat2.zp.values
        
        #Compute mean magnitude
        mean = np.mean(mag)
        #Compute rms
        rms = root_mean_squared_deviation(mag, mean)
        ax.scatter(mean - 0.2, rms, color = 'lightcoral', edgecolor = 'black', linewidth = 1,  s = 25)
        ax.text(mean, rms, str(ID), rotation = 90, size = 10, alpha = 0.7)
    ax.set_yscale('log')
    ax.set_xlabel('mag', size = 15)
    ax.set_ylabel('rms', size = 15)
    ax.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    ax.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    ax.minorticks_on()
    
def Check_LC(Star_File, ID, Filter = FILTERS[0]):
    """Plot lightcurve of chosen ID, using original data with zeropoints, to help identify the AGN.
    Inputs:
    Star_File - Original file with zeropoints (pd dataframe).
    ID - ID of potential AGN (float)
    Filter - (Optional) Filter (str).
    """
    #Set up figure
    fig, ax = plt.subplots(1,1, figsize = (12, 4))
    
    #Select data based on ID and Filter
    dat = Locate_Star_ID(Locate_Star_Filter(Star_File, Filter), ID)
    mag = dat.mag_aper.values + dat.zp.values
    MJD = dat.MJD
    ax.set_title('Star ID: '+ str(ID))
    ax.scatter(MJD, mag, s = 15, color = 'lightcoral', edgecolor = 'black', linewidth = 1)
    ax.set_xlabel('MJD', size = 15)
    ax.set_ylabel('Mag', size = 15)
    ax.tick_params(axis = 'both', which = 'major', direction = 'in', length = 6)
    ax.tick_params(axis = 'both', which = 'minor', direction = 'in', length = 4)
    ax.minorticks_on()
    ax.invert_yaxis()

def Convergence_Plot(Traces, mode, parameter, parameter_err = None, lim = 20, zoom = 0):
    """ Plots of traces of intercalibration algorithm to visualy check convergence.
    
    Input:
    Traces - Traces output from Corr() (list).
    mode - 'double' or 'single' (str): 
            double: Plots the trace of a parameter AND its error as a shell around the plot.
            single: Plots any ONE parameter from the list above. Can be used to check the convergence of the errors themselves for example.
    parameter - parameter to check convergence for (str).
    parameter_err - (Optional) if mode = 'double', parameter error (str).
    lim - (Optional) number of plots to show for epoch & star_mag related parameters as there can be a lot of them (int).
    zoom - (Optional) zoom into convergence tail to check finer fluctuations (int).
    """
    
    #Possible parameters to check
    names = ['dmagt', 'dmagt_err', 'rms_t', 'rms_t_err', 'DMAGT', 'DMAGT_err', 'dmags', 'dmags_err', 'rms_sc',
             'rms_sc_err', 'm_star', 'm_star_err', 'rms_star', 'rms_star_err', 'DMAGS', 'DMAGS_err']
    scope_pars = ['dmags', 'dmags_err', 'rms_sc','rms_sc_err', 'DMAGS', 'DMAGS_err']
    time_pars = ['dmagt', 'dmagt_err', 'rms_t','rms_t_err', 'DMAGT', 'DMAGT_err']
    star_pars = ['m_star', 'm_star_err', 'rms_star', 'rms_star_err']
    iterations = len(Traces[-1])
    
    if mode == 'single':
        par = Traces[names.index(parameter)]
            
        if parameter in scope_pars:
            COLORS = ['red', 'green', 'blue', 'purple', 'black', 'gray', 'violet', 'teal', 'orange', 'yellow']

            #Set up plot
            fig,ax = plt.subplots(2,5, sharex = True, figsize = (20, 10))
            fig.subplots_adjust(wspace=0.3, hspace = 0)

            #Iteration number
            IT = np.arange(1, iterations + 1, 1)[zoom:]
            r = 0
            c = 0
            for j in range(len(TEL)):
                D = []
                for i in range(len(par)):
                    D.append(par[i][j])
                D = D[zoom:]

                ax[r, c].plot(IT, D, label = 'Scope: ' + str(TEL[j]), color = COLORS[j], lw = 0.8)
                ax[r, c].legend()

                if r == 1:
                    ax[r, c].set_xlabel('Iteration')
                if c == 0:
                    ax[r,c].set_ylabel(parameter)

                if c < 4:
                    c = c + 1
                else:
                    c = 0
                    r = r + 1
        
        elif parameter in time_pars:
            SCOPE = 0
            IT = np.arange(1, iterations + 1, 1)[zoom:]
            r = 0
            c = 0            
            if lim < 4:
                lim = 4
            fig, ax = plt.subplots(math.ceil(lim / 4), 4, sharex = True, figsize = (30,20))
            fig.subplots_adjust(wspace=0.2, hspace = 0)
            for i in range(lim):
                D = []
                for j in range(len(par)):
                    D.append(par[j][SCOPE][i])
                D = D[zoom:]
                ax[r, c].plot(IT, D, label = 'MJD: ' + str(i), lw = 0.8)
                ax[r, c].legend()
                ax[r, c].ticklabel_format(useOffset=False)

                if r > math.ceil(lim / 4) - 4:
                    ax[r, c].set_xlabel('Iteration')
                if c == 0:
                    ax[r,c].set_ylabel(parameter)

                if c < 3:
                    c = c + 1
                else:
                    c = 0
                    r = r + 1

        elif parameter in star_pars:
            IT = np.arange(1, iterations + 1, 1)[zoom:]
            r = 0
            c = 0

            if lim < 5:
                lim = 5

            fig,ax = plt.subplots(math.ceil(lim / 5), 5, sharex = True, sharey = False, figsize = (20, 10))
            fig.subplots_adjust(wspace=0.4, hspace = 0)
            for j in range(lim):
                D = []
                for i in range(len(par)):
                    D.append(par[i][j])
                D = D[zoom:]
                ax[r, c].plot(IT, D, lw = 0.8)
                ax[r, c].ticklabel_format(useOffset=False)

                if r > math.ceil(lim / 5) - 5:
                    ax[r, c].set_xlabel('Iteration')
                if c == 0:
                    ax[r,c].set_ylabel(parameter)

                if c < 4:
                    c = c + 1
                else:
                    c = 0
                    r = r + 1         
    
    elif mode == 'double':
        
        if parameter_err == None:
            raise ValueError("With 'double' mode, par and par_err must be specified.")
        
        else:
            par = Traces[names.index(parameter)]
            par_err = Traces[names.index(parameter_err)]
            
            if parameter in scope_pars:
                COLORS = ['red', 'green', 'blue', 'purple', 'black', 'gray', 'violet', 'teal', 'orange', 'yellow']

                #Set up plot
                fig,ax = plt.subplots(2,5, sharex = True, figsize = (20, 10))
                fig.subplots_adjust(wspace=0.3, hspace = 0)

                #Iteration number
                IT = np.arange(1, iterations + 1, 1)[zoom:]
                r = 0
                c = 0
                for j in range(len(TEL)):
                    D = []
                    D_err = []
                    for i in range(len(par)):
                        D.append(par[i][j])
                        D_err.append(par_err[i][j])
                    D = D[zoom:]
                    D_err = D_err[zoom:]
                    
                    ax[r, c].plot(IT, D, label = 'Scope: ' + str(TEL[j]), color = COLORS[j], lw = 0.8)
                    ax[r, c].legend()

                    ax[r, c].plot(IT, np.array(D) + np.array(D_err), color = COLORS[j], lw = 0.8, ls = '--')
                    ax[r, c].plot(IT, np.array(D) - np.array(D_err), color = COLORS[j], lw = 0.8, ls = '--')
                    ax[r, c].fill_between(IT, np.array(D) - np.array(D_err), np.array(D) + np.array(D_err), color = COLORS[j], alpha = 0.3)

                    if r == 1:
                        ax[r, c].set_xlabel('Iteration')
                    if c == 0:
                        ax[r,c].set_ylabel(parameter)

                    if c < 4:
                        c = c + 1
                    else:
                        c = 0
                        r = r + 1
            elif parameter in time_pars:
                SCOPE = 0
                IT = np.arange(1, iterations + 1, 1)[zoom:]
                r = 0
                c = 0            
                if lim < 4:
                    lim = 4
                fig, ax = plt.subplots(math.ceil(lim / 4), 4, sharex = True, figsize = (30,20))
                fig.subplots_adjust(wspace=0.2, hspace = 0)
                for i in range(lim):
                    D = []
                    D_err = []
                    for j in range(len(par)):
                        D.append(par[j][SCOPE][i])
                        D_err.append(par_err[j][SCOPE][i])
                    D = D[zoom:]
                    D_err = D_err[zoom:]
                    ax[r, c].plot(IT, D, label = 'MJD: ' + str(i), lw = 0.8)
                    ax[r, c].legend()
                    ax[r, c].ticklabel_format(useOffset=False)

                    ax[r, c].plot(IT, np.array(D) + np.array(D_err), lw = 0.8, ls = '--', color = 'blue')
                    ax[r, c].plot(IT, np.array(D) - np.array(D_err), lw = 0.8, ls = '--', color = 'blue')
                    ax[r, c].fill_between(IT, np.array(D) - np.array(D_err), np.array(D) + np.array(D_err), alpha = 0.3)

                    if r > math.ceil(lim / 4) - 4:
                        ax[r, c].set_xlabel('Iteration')
                    if c == 0:
                        ax[r,c].set_ylabel(parameter)

                    if c < 3:
                        c = c + 1
                    else:
                        c = 0
                        r = r + 1

            elif parameter in star_pars:
                IT = np.arange(1, iterations + 1, 1)[zoom:]
                r = 0
                c = 0

                if lim < 5:
                    lim = 5

                fig,ax = plt.subplots(math.ceil(lim / 5), 5, sharex = True, sharey = False, figsize = (20, 10))
                fig.subplots_adjust(wspace=0.4, hspace = 0)
                for j in range(lim):
                    D = []
                    D_err = []
                    for i in range(len(par)):
                        D.append(par[i][j])
                        D_err.append(par_err[i][j])
                    D = D[zoom:]
                    D_err = D_err[zoom:]
                    ax[r, c].plot(IT, D, lw = 0.8)
                    ax[r, c].ticklabel_format(useOffset=False)

                    ax[r, c].plot(IT, np.array(D) + np.array(D_err), lw = 0.8, ls = '--')
                    ax[r, c].plot(IT, np.array(D) - np.array(D_err), lw = 0.8, ls = '--')
                    ax[r, c].fill_between(IT, np.array(D) - np.array(D_err), np.array(D) + np.array(D_err), alpha = 0.3)

                    if r > math.ceil(lim / 5) - 5:
                        ax[r, c].set_xlabel('Iteration')
                    if c == 0:
                        ax[r,c].set_ylabel(parameter)

                    if c < 4:
                        c = c + 1
                    else:
                        c = 0
                        r = r + 1
    else:
        raise ValueError(str(mode) + " is not a valid mode input. Please select either 'single' or 'double'.")

def Corr(Star_File, Filter, MAX_LOOPS = 100, bad_IDs = [], safe = 0.5, frac = 0.3, TEL = TEL, AGN_ID = AGN_ID, Star_Lim = None):
    """Algorithm to compute telescope correction parameters, star magnitudes, and extra variance as well as their errors.
    
    Input:
    Star_File - Uncalibrated Star data (pd dataframe).
    Filter - Filter (String).
    MAX_LOOPS - (Optional) Max number of loops if convergence isn't reached (int).
    bad_IDs - (Optional) Star IDs to omit in the calibration, primarily meant for variable stars (list/array).
    safe - (Optional) Safety step size to avoid overstepping (float).
    frac - (Optional) Fraction of stars with number of datapoints above max number of epochs (float).
    TEL - Telescope list (list/array)
    AGN_ID - AGN ID (float)
    
    Output:
    df - dataframe containing final correction parameters, extra variances, their errors and their corrected magnitudes (pd dataframe)
    TRACES - Traces of the correction parameters (arrays)
    """
    # ======================= Select stars to use in calibration ========================
    
    start_time = time.time()

    #Select stars based on fraction of datapoints
    Star_IDs = Brightest_Reduced(Star_File, Filter, frac = frac)
    
    #Further select stars based on variability
    if len(bad_IDs) > 0:
        Star_IDs = [i for i in Star_IDs if i not in bad_IDs]
    
    #Max number of stars to use
    if Star_Lim != None and len(Star_IDs) > Star_Lim:
        Star_IDs = Star_IDs[:Star_Lim]
    
    #Total number of stars used in calibration
    Nstars = len(Star_IDs)
    print('No. of Stars: '+ str(Nstars))
    
    #Select star data based on filter and star IDs
    Star_Data = Locate_Star_Filter(Locate_Star_ID(Star_File, Star_IDs), Filter)
    
    # ====================== Dataframes to return at the end ============================
    
    #Length of correction parameter dataframe
    df_length = len(Star_Data['id_apass'].values)
    
    #'scope_mean' and 'scope_mean_err' are temporary and will not be returned in the end, just used to speed up computation
    df = pd.DataFrame({'id_apass': Star_Data['id_apass'].values, 'Filter': Star_Data['Filter'].values,
                       'MJD': Star_Data['MJD'].values, 'telid': Star_Data['telid'].values, 
                       'scope_mean': df_length*[0.], 'scope_mean_err': df_length*[0.],
                       'mag_aper': Star_Data['mag_aper'].values, 'err_aper': Star_Data['err_aper'].values,
                       'dmagt': df_length*[0.], 'dmagt_err': df_length*[0.], 'rms_t': df_length*[0.],
                       'rms_t_err': df_length*[0.], 'dmags': df_length*[0.], 'dmags_err': df_length*[0.],
                       'rms_sc':df_length*[0.], 'rms_sc_err': df_length*[0.], 'm_star': df_length*[0.],
                       'm_star_err': df_length*[0.], 'rms_star': df_length*[0.], 'rms_star_err': df_length*[0.],
                       'err_tot': Star_Data['err_aper'].values, 'DMAGT': df_length*[0.], 'DMAGS': df_length*[0.],
                       'DMAGT_err': df_length*[0.], 'DMAGS_err': df_length*[0.],
                       'airmass': Star_Data['airmass'].values, 'seeing': Star_Data['seeing'].values}) 
    
    df.set_index(Star_Data.index.values)
    
    # ==================================== Traces ====================================
    
    epochs = []
    dmagts = [] ; dmagts_err = []
    rms_ts = [] ; rms_ts_err = []
    dmagss = [] ; dmagss_err = []
    rms_scs = [] ; rms_scs_err = []
    mstars = [] ; mstars_err = []
    rms_stars = [] ; rms_stars_err = []
    DMAGTS = [] ; DMAGTS_err = []
    DMAGSS = [] ; DMAGSS_err = []
    
    # ============================== Intercalibration =================================
    
    star_mag_correction = 0.
    loop = 0
    
    while loop < MAX_LOOPS:
        # ---------------------------------  1.dmagt  -------------------------------
        if __name__ == "__main__":
            df = process_dataframe_parallel2(df, loop, safe, Star_IDs)
 
        # ---------------------------------- 2.dmags --------------------------------
        #Compute temporary mean of star using optimal average
        for ID in Star_IDs:
            
            #Select data for particular telescope
            Star_Data_S = Locate_Star_ID(df, ID)
            Mag = Star_Data_S['mag_aper'].values
            Err = Star_Data_S['err_tot'].values
            Mean, Mean_err = Generic_Optimal(Mag, Err)
            
            #Update dataframe
            df['m_star'] = np.where((df['id_apass'] == ID), Mean, df['m_star'])
            df['m_star_err'] = np.where((df['id_apass'] == ID), Mean_err, df['m_star_err'])
        
        if __name__ == "__main__":
            df = process_dataframe_parallel3(df, loop, safe, Star_IDs)
            
        
        dat = df
        initial_DMAGS = (dat.DMAGS + dat.dmags).values
        DMAGS_mean, DMAGS_mean_err = Generic_Optimal(initial_DMAGS, dat['dmags_err'].values)
        Corr1 = dat.dmags - DMAGS_mean
        
        #Propagate error in new dmagt accordingly
        CorrE = (dat.dmags_err**2 + DMAGS_mean_err**2)**0.5
        
        #Update dataframe
        df.loc[Corr1.index, 'dmags'] = Corr1
        df.loc[CorrE.index, 'dmags_err'] = CorrE
        df.loc[CorrE.index, 'DMAGS_err'] = CorrE
        
        #Updated dataframe
        dat = df
        Corr2 = dat.DMAGS + dat.dmags
        
        #Update dataframe
        df.loc[Corr2.index, 'DMAGS'] = Corr2
        
        # ---------------------------------- 3.m(*) --------------------------------
        
        if __name__ == "__main__":
            df = process_dataframe_parallel4(df, loop, safe)
        
        first_star_corr = Locate_Star_ID(df, Star_IDs[0])['m_star'].values[0]
        df['mag_aper'] = df['mag_aper'].values - first_star_corr
        df['m_star'] = df['m_star'].values - first_star_corr

        #Keep track of how much the data shifts in total
        star_mag_correction = star_mag_correction + first_star_corr
        
        #Update traces
        epochs_temp = []
        dmagts_temp = [] ; dmagts_err_temp = []
        rms_t_temp = [] ; rms_t_err_temp = []
        DMAGTS_temp = [] ; DMAGTS_err_temp = []
        dmagss_temp = [] ; dmagss_err_temp = []
        DMAGSS_temp = [] ; DMAGSS_err_temp = []
        rms_sc_temp = [] ; rms_sc_err_temp = []
        
        for Telescope in TEL:
            Star_Data_T = Locate_Star_Scope(df, Telescope)
            dat = Star_Data_T.sort_values('MJD')
            
            unique_epochs = pd.unique(dat['MJD'])
            all_epochs = dat['MJD'].values
            unique_epoch_ind = np.array([list(all_epochs).index(e) for e in unique_epochs])
            
            epochs_temp.append(dat['MJD'].values[unique_epoch_ind])
            dmagts_temp.append(dat['dmagt'].values[unique_epoch_ind]) ; dmagts_err_temp.append(dat['dmagt_err'].values[unique_epoch_ind])
            rms_t_temp.append(dat['rms_t'].values[unique_epoch_ind]) ; rms_t_err_temp.append(dat['rms_t_err'].values[unique_epoch_ind])
            DMAGTS_temp.append(dat['DMAGT'].values[unique_epoch_ind]) ; DMAGTS_err_temp.append(dat['DMAGT_err'].values[unique_epoch_ind])
            

            dmagss_temp.append(dat['dmags'].values[0]) ; dmagss_err_temp.append(dat['dmags_err'].values[0])
            DMAGSS_temp.append(dat['DMAGS'].values[0]) ; DMAGSS_err_temp.append(dat['DMAGS_err'].values[0])
            rms_sc_temp.append(dat['rms_sc'].values[0]) ; rms_sc_err_temp.append(dat['rms_sc_err'].values[0])
            
        epochs.append(np.array(epochs_temp, dtype = object))
        dmagts.append(np.array(dmagts_temp, dtype = object)) ; dmagts_err.append(np.array(dmagts_err_temp, dtype = object))
        rms_ts.append(np.array(rms_t_temp, dtype = object)) ; rms_ts_err.append(np.array(rms_t_err_temp, dtype = object))
        DMAGTS.append(np.array(DMAGTS_temp, dtype = object)) ; DMAGTS_err.append(np.array(DMAGTS_err_temp, dtype = object))
        dmagss.append(np.array(dmagss_temp, dtype = object)) ; dmagss_err.append(np.array(dmagss_err_temp, dtype = object))
        DMAGSS.append(np.array(DMAGSS_temp, dtype = object)) ; DMAGSS_err.append(np.array(DMAGSS_err_temp, dtype = object))
        rms_scs.append(np.array(rms_sc_temp, dtype = object)) ; rms_scs_err.append(np.array(rms_sc_err_temp, dtype = object))
        
        mstars_temp = [] ; mstars_err_temp = []
        rms_stars_temp = [] ; rms_stars_err_temp = []
        for ID in Star_IDs:
            dat = Locate_Star_ID(df, ID)
            mstars_temp.append(dat['m_star'].values[0]) ; mstars_err_temp.append(dat['m_star_err'].values[0])
            rms_stars_temp.append(dat['rms_star'].values[0]) ; rms_stars_err_temp.append(dat['rms_star_err'].values[0])
        mstars.append(np.array(mstars_temp)) ; mstars_err.append(np.array(mstars_err_temp))
        rms_stars.append(np.array(rms_stars_temp)) ; rms_stars_err.append(np.array(rms_stars_err_temp))

        TRACES = [dmagts, dmagts_err, rms_ts, rms_ts_err, DMAGTS, DMAGTS_err, dmagss, dmagss_err, rms_scs, rms_scs_err, mstars, mstars_err, rms_stars, rms_stars_err, DMAGSS, DMAGSS_err]

        #Check convergence
        tiny = 1*10**(-4)
        All_par_diff = []
        if loop > 0:
            for i in np.arange(0, len(TRACES) - 1, 2):
                tr = TRACES[i]
                tr_err = TRACES[i+1]
                try:
                    All_par_diff.extend((abs((np.concatenate(tr[-1]).ravel()) - (np.concatenate(tr[-2]).ravel())) / np.concatenate(tr_err[-1]).ravel()))
                except:
                    All_par_diff.extend((abs(np.array(tr[-1]) - np.array(tr[-2])) / np.array(tr_err[-1])))
            #print(max(All_par_diff))
            if max(All_par_diff) < tiny:
                print('Iteration ' + str(loop + 1) + '/' + str(MAX_LOOPS) + ' Finished')
                df['m_star'] = df['m_star'] + star_mag_correction
                df['mag_aper'] = df['mag_aper'] + star_mag_correction
                print('Total run time:', time.time() - start_time)
                break
        print('Iteration ' + str(loop + 1) + '/' + str(MAX_LOOPS) + ' Finished')
        loop = loop + 1
    if loop == MAX_LOOPS:
        df['m_star'] = df['m_star'] + star_mag_correction
        df['mag_aper'] = df['mag_aper'] + star_mag_correction
        print('Total run time:', time.time() - start_time)
    return(df, TRACES)

def dmagt(object_data, loop, safe, Star_IDs):
    object_data = object_data
    #Compute avg and sig(avg) magnitude of the telescope for each star using optimal average
    for ID in Star_IDs:
        #Select data for particular star
        Star_Data_S = Locate_Star_ID(object_data, ID)
        Mag = Star_Data_S['mag_aper'].values
        Err = Star_Data_S['err_tot'].values
        Mean_Scope_Mag, Mean_Scope_Mag_Err = Generic_Optimal(Mag, Err)
        
        #Update dataframe
        object_data['scope_mean'] = np.where((object_data['id_apass'] == ID), Mean_Scope_Mag, object_data['scope_mean'])
        object_data['scope_mean_err'] = np.where((object_data['id_apass'] == ID), Mean_Scope_Mag_Err, object_data['scope_mean_err'])
    
    Unique_MJDs = pd.unique(object_data['MJD'])     
    Telescope = object_data.telid.values[0]
    if __name__ == "__main__":
        # Process the DataFrame in parallel
        result_df = process_dataframe_parallel(object_data, loop, safe, Telescope)

    #Ensure average of the dmagt's is zero to avoid degeneracy
    Star_Data_T = result_df
    initial_DMAGT = Star_Data_T['DMAGT'].values + Star_Data_T['dmagt'].values
    DMAGT_mean, DMAGT_mean_err = Generic_Optimal(initial_DMAGT, Star_Data_T['dmagt_err'].values)
    Corr1 = Star_Data_T.dmagt - DMAGT_mean
    
    #Propagate error in new dmagt accordingly
    CorrE = (Star_Data_T.dmagt_err**2 + DMAGT_mean_err**2)**0.5
    
    #Update dataframe
    result_df.loc[Corr1.index, 'dmagt'] = Corr1
    result_df.loc[CorrE.index, 'dmagt_err'] = CorrE
    result_df.loc[CorrE.index, 'DMAGT_err'] = CorrE
    
    #Updated dataframe
    Star_Data_T = result_df
    Corr2 = Star_Data_T.dmagt + Star_Data_T.DMAGT
    
    #Update dataframe
    result_df.loc[Corr2.index, 'DMAGT'] = Corr2
    
    return(result_df)
    
def dmags(object_data, loop, safe, Star_IDs):
    #Select data for particular telescope
    object_data = object_data
    Shifted_Mags = (object_data.mag_aper - object_data.dmagt - object_data.m_star).values
    
    #Propagate errors accordingly
    Shifted_Mags_Err = ((object_data.err_tot**2 + object_data.dmagt_err**2 + object_data.m_star_err**2)**0.5).values
    
    #Compute correction parameters for that telescope
    dmags, dmags_err, rms_sc, rms_sc_err = arx.arx(Shifted_Mags, Shifted_Mags_Err, 1000)
    Telescope = object_data.telid.values[0]
    #Take safe size step to prevent overstepping
    if loop > 0:
        #old_dmags = df.loc[df['telid'] == TEL[i]]['dmags'].values[0]
        old_rms_sc = object_data['rms_sc'].values[0]
    
        #dmags = old_dmags + safe*(dmags - old_dmags)
        rms_sc = old_rms_sc + safe*(rms_sc - old_rms_sc)
    
    #Update dataframe
    object_data['dmags'] = np.where((object_data['telid'] == Telescope), dmags, object_data['dmags'])
    object_data['dmags_err'] = np.where((object_data['telid'] == Telescope), dmags_err, object_data['dmags_err'])
    object_data['rms_sc'] = np.where((object_data['telid'] == Telescope), rms_sc, object_data['rms_sc'])             
    object_data['rms_sc_err'] = np.where((object_data['telid'] == Telescope), rms_sc_err, object_data['rms_sc_err'])
    object_data['DMAGS_err'] = np.where((object_data['telid'] == Telescope), dmags_err, object_data['DMAGS_err'])
    
    #Update total error of datapoint with new sc_rms
    err_tot = (object_data.err_aper**2 + object_data.rms_star**2 + object_data.rms_t**2 + rms_sc**2)**0.5
    
    #Update dataframe
    object_data.loc[err_tot.index, 'err_tot'] = err_tot
    
    return object_data

def star_mags(object_data, loop, safe):
    object_data = object_data
    
    #Apply correction parameters
    Shifted_Mag = object_data['mag_aper'].values - object_data['dmagt'].values - object_data['dmags'].values
    
    #Propagate errors accordingly
    Shifted_Mag_Err = (object_data['err_tot'].values**2 + object_data['dmagt_err'].values**2 + object_data['dmags_err'].values**2)**0.5
    
    #Compute mean of star
    m_star, m_star_err, rms_star, rms_star_err = arx.arx(Shifted_Mag, Shifted_Mag_Err, 1000)
    
    #Take safe size step to prevent overstepping
    if loop > 0:
        old_mstar = object_data['m_star'].values[0]
        old_rms_star = object_data['rms_star'].values[0]
        
        m_star = old_mstar + safe*(m_star - old_mstar)
        rms_star = old_rms_star + safe*(rms_star - old_rms_star)
    
    #Update dataframe
    object_data['m_star'] = m_star
    object_data['m_star_err'] = m_star_err
    object_data['rms_star'] = rms_star           
    object_data['rms_star_err'] = rms_star_err
    
    #Update total error of datapoint with new star_rms
    err_tot = (object_data.err_aper**2 + object_data.rms_sc**2 + object_data.rms_t**2 + rms_star**2)**0.5
    Shift = object_data.mag_aper - object_data.dmagt - object_data.dmags
    
    #Update dataframe
    object_data.loc[err_tot.index, 'err_tot'] = err_tot
    object_data.loc[Shift.index, 'mag_aper'] = Shift
    
    return object_data
def process_object_data(object_data, loop, safe, Telescope):
    object_data = object_data
    #Shift mags by telescope means
    Shifted_Mags = (object_data.mag_aper - object_data.scope_mean).values

    #Propagate errors accordingly
    Shifted_Mags_Err = ((object_data.err_tot**2 + object_data.scope_mean_err**2)**0.5).values

    #Compute correction parameters for that MJD
    dmagt, dmagt_err, rms_t, rms_t_err = arx.arx(Shifted_Mags, Shifted_Mags_Err, 1000)
    
    #Take safe size step to prevent overstepping
    if loop > 0:
        old_dmagt = object_data['dmagt'].values[0]
        old_rms_t = object_data['rms_t'].values[0]
        
        dmagt = old_dmagt + safe*(dmagt - old_dmagt)
        rms_t = old_rms_t + safe*(rms_t - old_rms_t)
    
    E = object_data.MJD.values[0]
    #Update dataframe
    object_data['dmagt'] = np.where((object_data['MJD'] == E), dmagt, object_data['dmagt'])
    object_data['dmagt_err'] = np.where((object_data['MJD'] == E), dmagt_err, object_data['dmagt_err'])
    object_data['rms_t'] = np.where((object_data['MJD'] == E), rms_t, object_data['rms_t'])
    object_data['rms_t_err'] = np.where((object_data['MJD'] == E), rms_t_err, object_data['rms_t_err'])

    #Updated dataframe
    Star_Data_T = Locate_Star_Scope(object_data, Telescope)
    Star_Data_mjd = Locate_Star_Epoch(Star_Data_T, E)
    
    #Update total error of datapoint with new t_rms
    err_tot = (Star_Data_mjd.err_aper**2 + Star_Data_mjd.rms_sc**2 + Star_Data_mjd.rms_star**2 + rms_t**2)**0.5
    
    #Update dataframe
    object_data.loc[err_tot.index, 'err_tot'] = err_tot
    
    return object_data

def process_dataframe_parallel(df, loop, safe, Telescope, num_workers=4):
    # Group the DataFrame by the 'MJD' column
    grouped_data = df.groupby('MJD')

    # Prepare the input for parallel processing
    object_data_list = [object_subset for _, object_subset in grouped_data]

    # Use ThreadPoolExecutor to parallelize the processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_object_data, object_subset, loop, safe, Telescope): object_subset for object_subset in object_data_list}

    # Collect the results and concatenate into a single DataFrame
    results = [future.result() for future in as_completed(futures)]
    result_df = pd.concat(results, ignore_index=True)

    result_df = pd.concat(results, ignore_index=True)
    return result_df


def process_dataframe_parallel2(df, loop, safe, Star_IDs, num_workers=4):
    # Group the DataFrame by the 'telid' column
    grouped_data = df.groupby('telid')

    # Prepare the input for parallel processing
    object_data_list = [object_subset for _, object_subset in grouped_data]

    # Use ThreadPoolExecutor to parallelize the processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(dmagt, object_subset, loop, safe, Star_IDs): object_subset for object_subset in object_data_list}

    # Collect the results and concatenate into a single DataFrame
    results = [future.result() for future in as_completed(futures)]
    result_df = pd.concat(results, ignore_index=True)

    return result_df

def process_dataframe_parallel3(df, loop, safe, Star_IDs, num_workers=4):
    # Group the DataFrame by the 'telid' column
    grouped_data = df.groupby('telid')

    # Prepare the input for parallel processing
    object_data_list = [object_subset for _, object_subset in grouped_data]

    # Use ThreadPoolExecutor to parallelize the processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(dmags, object_subset, loop, safe, Star_IDs): object_subset for object_subset in object_data_list}

    # Collect the results and concatenate into a single DataFrame
    results = [future.result() for future in as_completed(futures)]
    result_df = pd.concat(results, ignore_index=True)

    return result_df

def process_dataframe_parallel4(df, loop, safe, num_workers=4):
    # Group the DataFrame by the 'telid' column
    grouped_data = df.groupby('id_apass')

    # Prepare the input for parallel processing
    object_data_list = [object_subset for _, object_subset in grouped_data]

    # Use ThreadPoolExecutor to parallelize the processing
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(star_mags, object_subset, loop, safe): object_subset for object_subset in object_data_list}

    # Collect the results and concatenate into a single DataFrame
    results = [future.result() for future in as_completed(futures)]
    result_df = pd.concat(results, ignore_index=True)

    return result_df

#----------------
#Absolute photometric calibration functions
def Phot_Cal(DF,filt,catalogue='apass_updated.csv', AGN_ID=AGN_ID):
    df = {}
    apass = pd.read_csv(catalogue)
    for i,idx in enumerate(np.unique(DF['id_apass'])[:]):
        #print(idx)
        df_temp = Locate_Star_ID(DF,idx)
        #Apply correction parameters
        mags = df_temp['mag_aper'].values 
        #plt.plot(mags,'ko')
        errs = (df_temp['err_aper'].values**2 + df_temp['rms_sc'].values**2 + df_temp['rms_t'].values**2)**0.5
        mag,err = Generic_Optimal(mags,errs)

        ss = apass['id'] == idx
        #print(ss.sum())
        df[i] = {
                'id_apass':idx,
                'mag':df_temp['mag_aper'].values[0],
                'err':err,
                'mag_lco': apass[filt[:1]+'_lco'].values[ss][0],
                'err_lco': apass[filt[:1]+'err_lco'].values[ss][0]
                }
        #print(mag,df_temp['mag_aper'].values[0],apass[filt[:1]+'_lco'].values[ss])

    zps = pd.DataFrame.from_dict(df,"index")
    
    from astropy.stats import sigma_clip

    diffs =  zps['mag_lco'].values - zps['mag'].values
    diffs_err = np.sqrt(zps['err'].values**2 + zps['err_lco'].values**2)
    pp = sigma_clip(diffs, sigma=3)

    zp,aa,zp_err,ee = arx.arx(diffs[~pp.mask],diffs_err[~pp.mask])
    
    return zp, zp_err

## Files and data

In [None]:
# ------- Import star data in pickle format -------
lco = pd.read_pickle('lco_latest_stan.pkl')

# -------------- Obs info --------------
#AGN Name
obj_name = 'Mrk_841'

#FILTERS = pd.unique(lco['Filter'])
FILTERS = ['gp', 'V', 'rp', 'ip', 'zs'] #in order
print('FILTERS:', FILTERS)

TEL = list(pd.unique(lco['telid']))
print('SCOPES:', TEL)

AGN_ID = 1017

## Cleaning

In [None]:
# -------------- Clean up data file --------------
lco2=Clean(lco)

## Intercalibration

In [None]:
#Intercalibrate
DF_gp, TR_gp = Corr(lco2, 'gp', MAX_LOOPS = 100, bad_IDs = [], safe = 0.6, frac = 0.5, TEL = TEL, AGN_ID = AGN_ID, Star_Lim = 100)
DF_V, TR_V = Corr(lco2, 'V', MAX_LOOPS = 100, bad_IDs = [], safe = 0.6, frac = 0.5, TEL = TEL, AGN_ID = AGN_ID, Star_Lim = 100)
DF_rp, TR_rp = Corr(lco2, 'rp', MAX_LOOPS = 100, bad_IDs = [], safe = 0.6, frac = 0.5, TEL = TEL, AGN_ID = AGN_ID, Star_Lim = 100)
DF_ip, TR_ip = Corr(lco2, 'ip', MAX_LOOPS = 100, bad_IDs = [], safe = 0.6, frac = 0.5, TEL = TEL, AGN_ID = AGN_ID, Star_Lim = 100)
DF_zs, TR_zs = Corr(lco2, 'zs', MAX_LOOPS = 100, bad_IDs = [], safe = 0.6, frac = 0.5, TEL = TEL, AGN_ID = AGN_ID, Star_Lim = 100)

#Combine all into one dataframe
All_DF = [DF_gp, DF_V, DF_rp, DF_ip, DF_zs]
CALIBRATED_STARS = pd.concat(All_DF, ignore_index = True)
#Save as csv file
CALIBRATED_STARS.to_csv(obj_name +'_Calibrated_Stars'+'.csv')

In [None]:
#Load csv file
CALIBRATED_STARS = pd.read_csv(obj_name +'_Calibrated_Stars'+'.csv')

DF_gp = Locate_Star_Filter(CALIBRATED_STARS, 'gp')
DF_V = Locate_Star_Filter(CALIBRATED_STARS, 'V')
DF_rp = Locate_Star_Filter(CALIBRATED_STARS, 'rp')
DF_ip = Locate_Star_Filter(CALIBRATED_STARS, 'ip')
DF_zs = Locate_Star_Filter(CALIBRATED_STARS, 'zs')

## Zeropoint calibration

In [None]:
ZP_gp = Phot_Cal(DF_gp, 'gp', catalogue='apass_updated.csv')
ZP_V = Phot_Cal(DF_V, 'V', catalogue='apass_updated.csv')
ZP_rp = Phot_Cal(DF_rp, 'rp', catalogue='apass_updated.csv')
ZP_ip = Phot_Cal(DF_ip, 'ip', catalogue='apass_updated.csv')
ZP_zs = Phot_Cal(DF_zs, 'zs', catalogue='apass_updated.csv')

## AGN light curves

In [None]:
AGN_DF_gp = AGN_LC(lco2, DF_gp, 'gp', err_th = 0.03, Plot = True, Rem_out = True, zp=ZP_gp)

In [None]:
AGN_DF_V = AGN_LC(lco2, DF_V, 'V', err_th = 0.03, Plot = True, Rem_out = True, zp=ZP_V)

In [None]:
AGN_DF_rp = AGN_LC(lco2, DF_rp, 'rp', err_th = 0.03, Plot = True, Rem_out = True, zp=ZP_rp)

In [None]:
AGN_DF_ip = AGN_LC(lco2, DF_ip, 'ip', err_th = 0.03, Plot = True, Rem_out = True, zp=ZP_ip)

In [None]:
AGN_DF_zs = AGN_LC(lco2, DF_zs, 'zs', err_th = 0.03, Plot = True, Rem_out = True, zp=ZP_zs)

In [None]:
#Combine all AGN data into one dataframe
AGN = pd.concat([AGN_DF_gp, AGN_DF_V, AGN_DF_rp, AGN_DF_ip, AGN_DF_zs], ignore_index=True)
#Saving to file
AGN.to_csv(obj_name +'_AGN'+'.csv')

In [None]:
#Creating plots
for F, dff in zip(FILTERS, [DF_gp, DF_V, DF_rp, DF_ip, DF_zs])
    AGN_LC2(lco2, dff, fr'{F}', err_th = 0.03, Plot = True, Rem_out = False, zp=ZP_gp)

# PyCCF

In [None]:
#Load csv file
AGN_STARS = pd.read_csv(obj_name +'_AGN'+'.csv')
UVOT = pd.read_csv('Mrk841_UVOT.csv')

#Combining files
DF_AGN_W2 = UVOT.loc[UVOT['FILTER'] == "UVW2"] 
DF_AGN_W2 = DF_AGN_W2.drop(columns=["OBJECT", "MJD_ERR", "HJD", "OBSDATE", "OBSTIME", "OBSID", "EXTENSION", "FLUX_AA_5", "FLUX_AA_5_ERR_SIG5"])
DF_AGN_W2 = DF_AGN_W2.rename(columns={"FILTER":"Filter", "FLUX_HZ_5":"flux", "FLUX_HZ_5_ERR_SIG7P5":"err"})

In [None]:
#Conversions
def Jyconvert(mag):
    return 10**(23-0.4*(mag+48.6)) * 1e3
def errorconvert(err, flux):
    return flux*err*np.log(10)/2.5

In [None]:
#Converting LCO database units to Swift units
def converter(database):
    columns = ['MJD', 'Filter', 'flux', 'err']
    dataframe = pd.DataFrame(columns=columns)
    mjd1, flux1, err1, filter1 =  database['MJD'].to_numpy(), database['mag'].to_numpy(), database['err'].to_numpy(), database['Filter'].to_numpy()
    
    fluxjy = Jyconvert(flux1)
    # fluxjy = flux1
    errjy = errorconvert(err1, fluxjy)
    # errjy = err1

    temp = {'MJD': mjd1, 'Filter': filter1, 'flux': fluxjy, 'err': errjy}
    temp_DF = pd.DataFrame(temp)
    dataframe = pd.concat([dataframe, temp_DF], ignore_index=True)
    return dataframe

NEW_STARS = converter(AGN_STARS)
ALLSTARS = pd.concat([NEW_STARS, DF_AGN_W2], ignore_index=True)
#Saving to file
ALLSTARS.to_csv(obj_name +'_ALLSTARS'+'.csv')

In [None]:
#Converting to format usable in PyROA
def pyroadf2(database):
    FILTERS = ['gp', 'V', 'rp', 'ip', 'zs', 'UVW2'] #in order
    DF_LIST = []
    columns = ['MJD', 'flux', 'err']
    for i in range(len(FILTERS)):
        dataframe = pd.DataFrame(columns=columns)
        
        f1 = FILTERS[i]
        # print(f1)
        lc1 = Locate_Star_Filter(database, str(f1))
        # display(datum)
        mjd1, flux1, err1 =  lc1['MJD'].to_numpy(), lc1['flux'].to_numpy(), lc1['err'].to_numpy()
        # fluxjy = Jyconvert(flux1)
        # fluxjy = flux1
        # errjy = errorconvert(err1, fluxjy)
        # errjy = err1
        
        temp = {'MJD': mjd1, 'flux': flux1, 'err': err1}
        temp_DF = pd.DataFrame(temp)
        dataframe = pd.concat([dataframe, temp_DF], ignore_index=True)
        
        DF_LIST.append(dataframe)
    return DF_LIST

DF2_LIST = pyroadf2(ALLSTARS)

for i in range(len(DF2_LIST)):
    FILTERS = ['gp', 'V', 'rp', 'ip', 'zs', 'UVW2'] #in order
    df = DF2_LIST[i]
    Filter = FILTERS[i]
    df.to_csv(fr"/home/jovyan/PyROAF/{obj_name}_{FILTERS[i]}.dat", header=None, index=None, sep=" ")
    df.to_csv(fr"/home/jovyan/PyROAF/{obj_name}_{FILTERS[i]}1.dat", header=None, index=None, sep=" ")

## Using PyCCF

In [None]:
# adapted from sample_runcode.py script accessible at http://ascl.net/code/v/1868
# Reference: Peterson et al. 1998 (https://arxiv.org/abs/astro-ph/9802103)
def pyccfer_flux(database):
    
    start_time = time.time()
    columns = ['Centroid', 'Centroid upper error', 'Centroid lower error', 'Peak', 'Peak upper error', 'Peak lower error']
    df = pd.DataFrame(columns=columns)
    database = database.sort_values(by=['MJD'])
    FILTERS = ["gp", "V", "rp", "ip", "zs", "UVW2"]
    
    for j in range(2):
        
        print(j)
        
        if j == 0:
            ref_date1 = 57478
            ref_date2 = 57664
        
        if j == 1:
            ref_date1 = 57700
            ref_date2 = 10**6
        
        databased = database.loc[(database['MJD'] > ref_date1) & (database['MJD'] < ref_date2)]
        
        for i in range(len(FILTERS)):
            f1 = FILTERS[0]
            f2 = FILTERS[i]

            print(fr'Compared filters: {f1} and {f2}')

            lc1 = Locate_Star_Filter(databased, str(f1))
            lc2 = Locate_Star_Filter(databased, str(f2))

            mjd1, flux1, err1 =  lc1['MJD'].to_numpy(), lc1['flux'].to_numpy(), lc1['err'].to_numpy()
            mjd2, flux2, err2 =  lc2['MJD'].to_numpy(), lc2['flux'].to_numpy(), lc2['err'].to_numpy()


            #########################################
            ##Set Interpolation settings, user-specified
            #########################################
            lag_range = [-20, 20]  #Time lag range to consider in the CCF (days). Must be small enough that there is some overlap between light curves at that shift (i.e., if the light curves span 80 days, these values must be less than 80 days)
            interp = 0.5 #Interpolation time step (days). Must be less than the average cadence of the observations, but too small will introduce noise.
            nsim = 10000  #Number of Monte Carlo iterations for calculation of uncertainties
            mcmode = 0  #Do both FR/RSS sampling (1 = RSS only, 2 = FR only) 
            sigmode = 0.2  #Choose the threshold for considering a measurement "significant". sigmode = 0.2 will consider all CCFs with r_max <= 0.2 as "failed". See code for different sigmodes.

            ##########################################
            #Calculate lag with python CCF program
            ##########################################
            tlag_peak, status_peak, tlag_centroid, status_centroid, ccf_pack, max_rval, status_rval, pval = myccf.peakcent(mjd1, flux1, mjd2, flux2, lag_range[0], lag_range[1], interp)
            tlags_peak, tlags_centroid, nsuccess_peak, nfail_peak, nsuccess_centroid, nfail_centroid, max_rvals, nfail_rvals, pvals = myccf.xcor_mc(mjd1, flux1, abs(err1), mjd2, flux2, abs(err2), lag_range[0], lag_range[1], interp, nsim = nsim, mcmode=mcmode, sigmode = 0.2)

            lag = ccf_pack[1]
            r = ccf_pack[0]

            perclim = 84.1344746    

            ###Calculate the best peak and centroid and their uncertainties using the median of the
            ##distributions. 
            centau = stats.scoreatpercentile(tlags_centroid, 50)
            centau_uperr = (stats.scoreatpercentile(tlags_centroid, perclim))-centau
            centau_loerr = centau-(stats.scoreatpercentile(tlags_centroid, (100.-perclim)))
            print('Centroid, error: %10.3f  (+%10.3f -%10.3f)'%(centau, centau_uperr, centau_loerr))

            peaktau = stats.scoreatpercentile(tlags_peak, 50)
            peaktau_uperr = (stats.scoreatpercentile(tlags_peak, perclim))-centau
            peaktau_loerr = centau-(stats.scoreatpercentile(tlags_peak, (100.-perclim)))
            print('Peak, errors: %10.3f  (+%10.3f -%10.3f)'%(peaktau, peaktau_uperr, peaktau_loerr))

            temp = [{'Centroid':centau, 'Centroid upper error':centau_uperr, 'Centroid lower error':centau_loerr, 'Peak':peaktau, 'Peak upper error':peaktau_uperr, 'Peak lower error':peaktau_loerr, 'CCF':lag, 'Correlation': r, 'Centroid hist': tlags_centroid, 'Peak hist': tlags_peak}]

            temp_df = pd.DataFrame(temp)

            df = pd.concat([df, temp_df], ignore_index=True)

            ##########################################
            #Plot the Light curves, CCF, CCCD, and CCPD
            ##########################################

            fig = plt.figure()
            fig.subplots_adjust(hspace=0.2, wspace = 0.1)

            #Plot lightcurves
            ax1 = fig.add_subplot(3, 1, 1)
            ax1.errorbar(mjd1, flux1, yerr = err1, marker = '.', linestyle = ':', color = 'k', label = 'LC 1 (Continuum)')
            ax1_2 = fig.add_subplot(3, 1, 2, sharex = ax1)
            ax1_2.errorbar(mjd2, flux2, yerr = err2, marker = '.', linestyle = ':', color = 'k', label = 'LC 2 (Emission Line)')

            # ax1.text(0.025, 0.825, lc1, fontsize = 15, transform = ax1.transAxes)
            # ax1_2.text(0.025, 0.825, lc2, fontsize = 15, transform = ax1_2.transAxes)
            ax1.set_ylabel('LC 1 Flux')
            ax1_2.set_ylabel('LC 2 Flux')
            ax1_2.set_xlabel('MJD')

            #Plot CCF Information
            xmin, xmax = -99, 99
            ax2 = fig.add_subplot(3, 3, 7)
            ax2.set_ylabel('CCF r')
            ax2.text(0.2, 0.85, 'CCF ', horizontalalignment = 'center', verticalalignment = 'center', transform = ax2.transAxes, fontsize = 16)
            ax2.set_xlim(xmin, xmax)
            ax2.set_ylim(-1.0, 1.0)
            ax2.plot(lag, r, color = 'k')

            ax3 = fig.add_subplot(3, 3, 8, sharex = ax2)
            ax3.set_xlim(xmin, xmax)
            ax3.axes.get_yaxis().set_ticks([])
            ax3.set_xlabel('Centroid Lag: %5.1f (+%5.1f -%5.1f) days'%(centau, centau_uperr, centau_loerr), fontsize = 15) 
            ax3.text(0.2, 0.85, 'CCCD ', horizontalalignment = 'center', verticalalignment = 'center', transform = ax3.transAxes, fontsize = 16)
            n, bins, etc = ax3.hist(tlags_centroid, bins = 50, color = 'b')
            print(len(tlags_centroid))

            ax4 = fig.add_subplot(3, 3, 9, sharex = ax2)
            ax4.set_ylabel('N')
            ax4.yaxis.tick_right()
            ax4.yaxis.set_label_position('right') 
            #ax4.set_xlabel('Lag (days)')
            ax4.set_xlim(xmin, xmax)
            ax4.text(0.2, 0.85, 'CCPD ', horizontalalignment = 'center', verticalalignment = 'center', transform = ax4.transAxes, fontsize = 16)
            ax4.hist(tlags_peak, bins = bins, color = 'b')

            plt.show()
        print('Total run time:', time.time() - start_time)
    return df

In [None]:
FCCF_DF = pyccfer_flux(ALLSTARS)

In [None]:

#File organisation; adding year, waveband error, and waveband wavelength
DF_FCCF = FCCF_DF
Fwavelength = np.array([4770, 5510, 6231, 7625, 9134, 1894, 4770, 5510, 6231, 7625, 9134, 1894])
FFerror = np.array([1262.68/2, 840/2, 1149.52/2, 1238.95/2, 994.39/2, 584.89/2, 1262.68/2, 840/2, 1149.52/2, 1238.95/2, 994.39/2, 584.89/2])
Fyear = np.array([2016, 2016, 2016, 2016, 2016, 2016, 2017, 2017, 2017, 2017, 2017, 2017])
Fcentroid = DF_FCCF['Centroid']
FFILTERS = ["gp", "V", "rp", "ip", "zs", "UVW2"]
FFilters = np.tile(FFILTERS, 2)
Ftemp = {'Wavelength': Fwavelength, 'Year':Fyear, 'Filter':FFilters, 'Centroid':Fcentroid, 'Filter error':FFerror}
Ftemp_DF = pd.DataFrame(Ftemp, index=DF_FCCF.index)
DF_FCCF = pd.merge(DF_FCCF, Ftemp_DF, on='Centroid')

In [None]:
DF_FCCF.to_pickle(obj_name+'_FPyCCF'+'.pkl')
DF_FCCF = pd.read_pickle(obj_name+'_FPyCCF'+'.pkl')

# PyROA

## Importing files, data, making fits

In [None]:
#fit was completed six times, each with a different Delta and saved to different folders

In [None]:
obj_name = 'Mrk_841'
FILTERS = ['UVW2','gp1', 'gp', 'V', 'rp', 'ip', 'zs']
datadir = "/home/jovyan/PyROAF/"
filters = FILTERS
init_tau = [-3, 0, 0.7, 3.0, 3.5, 4.5]
priors = [[0.5, 2.0],[0.5, 2.0], [-50.0, 50.0], [0.5, 10.0], [0.0, 10.0]]
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True, init_tau = init_tau, Nsamples=70000, Nburnin=15000, use_backend=True, delay_ref='gp1') 
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True,Nsamples=30000, Nburnin=0,use_backend=True, resume_progress=True, delay_ref='gp1')
outputdir = "/home/jovyan/PyROAF/"

In [None]:
obj_name = 'Mrk_841'
FILTERS = ['UVW2','gp1', 'gp', 'V', 'rp', 'ip', 'zs']
datadir = "/home/jovyan/PyROA2/"
filters = FILTERS
init_tau = [-3, 0, 0.7, 3.0, 3.5, 4.5]
priors = [[0.5, 2.0],[0.5, 2.0], [-50.0, 50.0], [1.99, 2.01], [0.0, 10.0]]
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True, init_tau = init_tau, Nsamples=70000, Nburnin=15000, use_backend=True, delay_ref='gp1') 
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True,Nsamples=30000, Nburnin=0,use_backend=True, resume_progress=True, delay_ref='gp1')
outputdir = "/home/jovyan/PyROA2/"

In [None]:
obj_name = 'Mrk_841'
FILTERS = ['UVW2','gp1', 'gp', 'V', 'rp', 'ip', 'zs']
datadir = "/home/jovyan/PyROA5/"
filters = FILTERS
init_tau = [-3, 0, 0.7, 3.0, 3.5, 4.5]
priors = [[0.5, 2.0],[0.5, 2.0], [-50.0, 50.0], [4.99, 5.01], [0.0, 10.0]]
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True, init_tau = init_tau, Nsamples=70000, Nburnin=15000, use_backend=True, delay_ref='gp1') 
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True,Nsamples=30000, Nburnin=0,use_backend=True, resume_progress=True, delay_ref='gp1')
outputdir = "/home/jovyan/PyROA5/"

In [None]:
obj_name = 'Mrk_841'
FILTERS = ['UVW2','gp1', 'gp', 'V', 'rp', 'ip', 'zs']
datadir = "/home/jovyan/PyROA10/"
filters = FILTERS
init_tau = [-3, 0, 0.7, 3.0, 3.5, 4.5]
priors = [[0.5, 2.0],[0.5, 2.0], [-50.0, 50.0], [9.99, 10.01], [0.0, 10.0]]
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True, init_tau = init_tau, Nsamples=70000, Nburnin=15000, use_backend=True, delay_ref='gp1') 
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True,Nsamples=30000, Nburnin=0,use_backend=True, resume_progress=True, delay_ref='gp1')
outputdir = "/home/jovyan/PyROA10/"

In [None]:
obj_name = 'Mrk_841'
FILTERS = ['UVW2','gp1', 'gp', 'V', 'rp', 'ip', 'zs']
datadir = "/home/jovyan/PyROAD20/"
filters = FILTERS
init_tau = [-3, 0, 0.7, 3.0, 3.5, 4.5]
priors = [[0.5, 2.0],[0.5, 2.0], [-50.0, 50.0], [19.99, 20.01], [0.0, 10.0]]
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True, init_tau = init_tau, Nsamples=70000, Nburnin=15000, use_backend=True, delay_ref='gp1') 
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True,Nsamples=30000, Nburnin=0,use_backend=True, resume_progress=True, delay_ref='gp1')
outputdir = "/home/jovyan/PyROAD20/"

In [None]:
obj_name = 'Mrk_841'
FILTERS = ['UVW2','gp1', 'gp', 'V', 'rp', 'ip', 'zs']
datadir = "/home/jovyan/PyROA40/"
filters = FILTERS
init_tau = [-3, 0, 0.7, 3.0, 3.5, 4.5]
priors = [[0.5, 2.0],[0.5, 2.0], [-50.0, 50.0], [39.99, 40.01], [0.0, 10.0]]
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True, init_tau = init_tau, Nsamples=70000, Nburnin=15000, use_backend=True, delay_ref='gp1') 
fit = PyROA.Fit(datadir, obj_name, filters, priors, add_var=True,Nsamples=30000, Nburnin=0,use_backend=True, resume_progress=True, delay_ref='gp1')
outputdir = "/home/jovyan/PyROA40/"

# Plots

## Obtaining times of Swift XRT observations

In [None]:
names3 = ['time','time_pos','time_neg','count', 'count_pos', 'count_neg']
curvedf = pd.read_table('curve_plot.qdp',skiprows=3,names=names3, delimiter='\t')

time = curvedf['time'].values
t0 = 189304741.8006
time = time+t0

def convert_mjd(time):
    return 51910.00015488 + time/86400

mjd = convert_mjd(time)

swiftxrt = pd.DataFrame({'time':mjd, 'count':curvedf['count'].values})
swiftxrt.to_csv('swiftxrt.csv')

# Swift UVOT light curves

In [None]:
UVOT = pd.read_csv('Mrk841_UVOT.csv')
FILTERS = ['UVW2', 'UVM2', 'UVW1', 'U', 'B', 'V']
filtername=FILTERS
def uvot_curve(dataframe):
    
    def helper(l):
        sort_order = {k:v for k,v in zip(l, range(len(l)))}
        return lambda s: s.map(lambda x: sort_order[x])

    dataframed = dataframe.sort_values('FILTER', key=helper(FILTERS))
    dataframed = dataframed.sort_values(by=['MJD'])
    
    fig, axs = plt.subplots(len(FILTERS), figsize=(15,12), sharex='col')#, layout='constrained')
    axs[0].set_title(fr"Mrk 841 light curves")
    fig.subplots_adjust(wspace=0, hspace=0)
    xmin, xmax = -11, 11

    colors = ["lightseagreen", "darkgreen", "#8eab12", "darkorange", "crimson", "indigo", "lightseagreen", "darkgreen", "greenyellow", "darkorange", "crimson", "indigo"]

    fig.supylabel(r"$F_\nu$ (erg/$\mathrm{cm}^2$/s/Hz)", fontsize=20, position=(0.05,0.5))
    for i in range(len(FILTERS)):
        f1 = FILTERS[i]
        lc1 = dataframed.loc[(dataframed['FILTER'] == str(f1))]
        mjd1, mjderr1, flux1, err1 =  lc1['MJD'].to_numpy(), lc1['MJD_ERR'].to_numpy(), lc1['FLUX_HZ_5'].to_numpy(), lc1['FLUX_HZ_5_ERR_SIG7P5'].to_numpy()

        f2=filtername[i]

        axs[i].errorbar(mjd1, flux1, yerr = err1, xerr=mjderr1, marker = '.', linestyle = ':', color = colors[i])
        axs[i].set_ylabel(fr"{f2}")
        axs[i].invert_yaxis()
        
    axs[-1].set_xlabel("Modified Julian Date") 
    axs[-1].set_xlabel("Lag (days)")
    plt.savefig(fr"UVOT_curve.png", dpi=300)
uvot_curve(UVOT)

## PyCCF light curves and lags

In [None]:
#importing data
ALLSTARS = pd.read_csv(obj_name +'_ALLSTARS'+'.csv')
swiftxrt = pd.read_csv('swiftxrt.csv')

In [3]:
filtername=[r"UVW2", r"$g'$", r"$V$", r"$r'$", r"$i'$", r"$z_s$"]

def totalspec2(dataframe, database, time):
    database = database.sort_values(by=['MJD'])
   
    for j in range(2):
        if j == 0:
            dataframed = dataframe.loc[(dataframe['Year'] == 2016)]
            year = 2016
            
            ref_date1 = 57495
            ref_date2 = 57641
            
            
        else:
            dataframed = dataframe.loc[(dataframe['Year'] == 2017)]
            year = 2017
            
            ref_date1 = 57751
            ref_date2 = 57988
        
        databased = database.loc[(database['MJD'] > ref_date1) & (database['MJD'] < ref_date2)]
        
        
        def helper(l):
            sort_order = {k:v for k,v in zip(l, range(len(l)))}
            return lambda s: s.map(lambda x: sort_order[x])
        
        databased = databased.sort_values('Filter', key=helper(FILTERS))
        dataframed = dataframed.sort_values('Filter', key=helper(FILTERS))
        
        databased = databased.sort_values(by=['MJD'])
        
        databased = databased.reset_index(drop=True)
        dataframed = dataframed.reset_index(drop=True)
        
        fig, axs = plt.subplots(len(FILTERS), 2, figsize=(15,12), gridspec_kw={'width_ratios': [3, 1]}, sharex='col')#, layout='constrained')
        axs[0,0].set_title(fr"Mrk 841 light curves")
        axs[0,1].set_title("CCF")
        
        fig.subplots_adjust(wspace=0, hspace=0)
        
        xmin, xmax = -11, 11
        
        colors = ["lightseagreen", "darkgreen", "#8eab12", "darkorange", "crimson", "indigo", "lightseagreen", "darkgreen", "greenyellow", "darkorange", "crimson", "indigo"]
    
        fig.supylabel(fr"$F_\nu$ (mJy)", fontsize=20, position=(0.05,0.5))
        for i in range(len(FILTERS)):
            f1 = FILTERS[i]
            # print(f1)
            lc1 = Locate_Star_Filter(databased, str(f1))
            # display(lc1)
            datum = dataframed.loc[(dataframed['Filter'] == str(f1))]
            # display(datum)
            mjd1, flux1, err1 =  lc1['MJD'].to_numpy(), lc1['flux'].to_numpy(), lc1['err'].to_numpy()
            
            f2=filtername[i]
            
            nbins = 15
            lag = datum['CCF'][i]
            r = datum['Correlation'][i]
            tlags_centroid = datum['Centroid hist'][i]
                
                
            axs[i,0].errorbar(mjd1, flux1, yerr = err1, marker = '.', linestyle = ':', color = colors[i])
            axs[i,0].set_ylabel(fr"{f2}")
            axs[i,0].invert_yaxis()
            axs[i,0].vlines(time['time'], ymin=-20, ymax=20, zorder=0, color = 'lightgrey')
            axs[i,0].set_xlim(min(mjd1)*0.9999, max(mjd1)*1.0001)
            axs[i,0].set_ylim(min(flux1)*0.9, max(flux1)*1.1)
           
            axs[i,1].plot(lag, r, color='k')
            axs[i,1].set_xlim(xmin, xmax)
            axs[i, 1].set_ylim(0,1)
            axs[i,1].set_yticks([])
            ax2 = axs[i,1].twinx()
            ax2.hist(tlags_centroid, bins = nbins, color = colors[i])
            ax2.axvline(x = np.percentile(tlags_centroid, [16, 50, 84])[1], color="black")
            ax2.axvline(x = np.percentile(tlags_centroid, [16, 50, 84])[0], color="black", ls="--")
            ax2.axvline(x = np.percentile(tlags_centroid, [16, 50, 84])[2], color="black", ls="--")
            ax2.set_yticks([])
            ax3 = axs[i,1].twinx()
            ax3.set_yticks([])
            ax3.axvline(0, color="dodgerblue")
            
            if i == 0:
                ticker = [0,0.25, 0.5, 0.75, 1]
                labels = ["0.00, 1", "0.25", "0.50", "0.75", "1.00"]
                axs[i,1].yaxis.tick_right()
                axs[i,1].set_yticks(ticker)
                axs[i,1].set_yticklabels(labels)
            elif i == 5:
                ticker = [0,0.25, 0.5, 0.75]
                labels = ["0.00", "0.25", "0.50", "0.75"]
                axs[i,1].yaxis.tick_right()
                axs[i,1].set_yticks(ticker)
                axs[i,1].set_yticklabels(labels)
            else:
                ticker = [0,0.25, 0.5, 0.75]
                labels = ["0.00, 1", "0.25", "0.50", "0.75"]
                axs[i,1].yaxis.tick_right()
                axs[i,1].set_yticks(ticker)
                axs[i,1].set_yticklabels(labels)
            
            
        axs[-1,0].set_xlabel("Modified Julian Date") 
        axs[-1,1].set_xlabel("Lag (days)")
        plt.savefig(fr"totalspec{year}.png", dpi=300)

SyntaxError: invalid syntax (368577995.py, line 108)

In [None]:
totalspec2(DF_FCCF, ALLSTARS, swiftxrt)

## PyROA light curves and lags, flux-flux, and SED

In [None]:
#only fits from Delta=2 were chosen to be plotted

waves = [1894, 4770, 5510, 6231, 7625, 9134]
redshift = 0.0364
band_colors=['#0652DD','#1289A7','#006266','#A3CB38','orange','#EE5A24','brown']
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2/"
datadir = "/home/jovyan/PyROA2/"
gal_ref = 'UVW2'
burnin=150000
# following functions are adapted from Utils.py accessible at https://github.com/Alymantara/PyROA
# Reference: Donnan et al. 2021 (https://ui.adsabs.harvard.edu/abs/2021arXiv210712318D/abstract)

import scipy.interpolate as interpolate
import matplotlib.pyplot as plt
from scipy.stats import median_abs_deviation as mad
FILTERS = ["UVW2", "gp1", "gp", "V", "rp", "ip", "zs"]
filters = FILTERS
waves = [1894, 4770, 5510, 6231, 7625, 9134]#[3580,4392,4770,5468,6215,7545,8700]
redshift = 0.0364
band_colors=colors = ["lightseagreen", "darkgreen", "darkgreen","#8eab12", "darkorange", "crimson", "indigo", "brown"]
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2/"
datadir = "/home/jovyan/PyROA2/"
gal_ref = 'UVW2'
obj_name = 'Mrk_841'
swiftxrt = pd.read_csv('swiftxrt.csv')
burnin=150000

def Lightcurves(objName, filters, delay_ref, 
                lc_file="Lightcurve_models.obj",
                samples_file='samples_flat.obj',
                slow_comp_file='Slow_Comps.obj',
                outputdir = './', datadir='./',
                burnin=0, band_colors = None,
                limits=None, grid=False, grid_step=5.0,
                show_delay_ref=False, ylab = None,
                filter_labels = None, savefig=True, figname=None,
                include_slow_comp=False,slow_comp_delta=30.0, time=None
                ):

    if outputdir[-1] != '/': outputdir += '/'

    if ylab ==None: ylab = r"F$_{\nu}$"+"\nmJy"
    if filter_labels == None: filter_labels = filters

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    file = open(outputdir+samples_file,'rb')
    samples_flat = pickle.load(file)
    samples_flat = samples_flat[burnin:,:]
    file = open(outputdir+lc_file,'rb')
    models = pickle.load(file)

    if include_slow_comp:
        file = open(outputdir+slow_comp_file,'rb')
        slow_comps = pickle.load(file)


    #Split samples into chunks, 4 per lightcurve i.e A, B, tau, sig
    chunk_size=4
    transpose_samples = np.transpose(samples_flat)
    #Insert zero where tau_0 would be 
    transpose_samples = np.insert(transpose_samples, [ss*4+2], np.array([0.0]*len(transpose_samples[1])), axis=0)
    samples_chunks = [transpose_samples[i:i + chunk_size] for i in range(0, len(transpose_samples), chunk_size)] 



    fig = plt.figure(figsize=(20,len(filters)*3.5))
    fig.tight_layout()
    corro = 1
    if show_delay_ref: corro = 0
    gs = fig.add_gridspec(len(filters)-corro, 2, hspace=0, wspace=0, width_ratios=[5, 1])
    axs= gs.subplots(sharex='col')

    if band_colors == None:
        band_colors = ['k']*len(filters)

    #Loop over lightcurves

    data=[]
    ko = 0

    if limits !=None:
        xmin=limits[0]#59337
        xmax=limits[1]#59621

    for i in range(len(filters)):
        #Read in data
        file = datadir + objName+"_" + str(filters[i]) + ".dat"
        data.append(np.loadtxt(file))
        mjd = data[i][:,0]
        flux = data[i][:,1]
        err = data[i][:,2]    

        if (i == 0) & (limits == None):
            xmin = np.nanmin(mjd)-10
            xmax = np.nanmax(mjd)+10
        #Add extra variance
        B = np.percentile(samples_chunks[i][1], 50)
        sig = np.percentile(samples_chunks[i][3], 50)
        err = np.sqrt(err**2 + sig**2)
        fig.supylabel(fr"$F_\nu$ (mJy)", position=(50, 50))
        
        if ((filters[i] != delay_ref) ):

            gs00 = gridspec.GridSpecFromSubplotSpec(6, 1, subplot_spec=axs[i-ko][0],hspace=0)
            ax1 = fig.add_subplot(gs00[:, :])
            ax1.set_ylim(np.median(flux)-4.8*mad(flux),np.median(flux)+4.8*mad(flux))
            if i < len(filters)-1:
                ax1.set_xticklabels([])
            else:
                ax1.set_xlabel("MJD")
            axs[i-ko][0].set_yticklabels([])
            #Plot Data
            ax1.errorbar(mjd, flux , yerr=err, ls='none', marker=".", color=band_colors[i], ms=2)
            ax1.vlines(time['time'], ymin=-20, ymax=20, zorder=0, color = 'lightgrey')
            #Plot Model
            t, m, errs = models[i]
            new_m = np.interp(mjd,t, m)
            
            if grid:
                for hh in np.arange(59330,xmax,grid_step):
                    ax1.axvline(x=hh,ls='--',color='grey',alpha=0.4)
                
            ax1.set_xlim(xmin,xmax)
            
            ax1.plot(t,m, color="black", lw=3)

            if (include_slow_comp == True):
                slow_comp = slow_comps[i]
                
                ax1.plot(mjd, slow_comp(mjd)+B, linestyle="dashed", color="black")  
            filto = filter_labels[i]
            ax1.text(0.1,0.2,filto, color=band_colors[i], fontsize=19, transform=ax1.transAxes)
            ax1.fill_between(t, m+errs, m-errs, alpha=0.5, color="black")
           
            #Plot Time delay posterior distributions
            tau_samples = samples_chunks[i][2],
            axs[i-ko][1].hist(tau_samples, color=band_colors[i], bins=50,histtype='stepfilled')
            axs[i-ko][1].axvline(x = np.percentile(tau_samples, [16, 50, 84])[1], color="black")
            axs[i-ko][1].axvline(x = np.percentile(tau_samples, [16, 50, 84])[0] , color="black", ls="--")
            axs[i-ko][1].axvline(x = np.percentile(tau_samples, [16, 50, 84])[2], color="black",ls="--")
            axs[i-ko][1].axvline(x = 0, color="dodgerblue",ls="-")    
            axs[i-ko][1].set_xlabel("Lag (days)")
            axs[i-ko][1].set_yticklabels([])
            axs[i-ko][1].axes.get_yaxis().set_visible(True)
            axs[i-ko][0].set_xticklabels([])
            
            axs[0][0].set_title("Mrk 841 light curves")
            axs[0][1].set_title("CCF")


        if (filters[i] == delay_ref):
            
            if ((show_delay_ref == True)):
                gs00 = gridspec.GridSpecFromSubplotSpec(6, 1, subplot_spec=axs[i-ko][0],hspace=0)
                ax1 = fig.add_subplot(gs00[:-2, :])
                ax2 = fig.add_subplot(gs00[-2:, :])
                ax1.set_ylim(np.median(flux)-4.8*mad(flux),np.median(flux)+4.8*mad(flux))
                if i < len(filters)-1:
                    
                    ax1.set_xticklabels([])
                else:
                    ax1.set_xlabel("MJD")
                ax1.set_xticklabels([])
                axs[i-ko][0].set_yticklabels([])
                #Plot Data
                ax1.errorbar(mjd, flux , yerr=err, ls='none', marker=".", color=band_colors[i], ms=2)
                #Plot Model
                t, m, errs = models[i]
                new_m = np.interp(mjd,t, m)
                
                if grid:
                    for hh in np.arange(59330,xmax,5):
                        ax1.axvline(x=hh,ls='--',color='grey',alpha=0.4)
                    
                ax2.set_xlim(xmin,xmax)
                ax1.set_xlim(xmin,xmax)
                
                ax1.plot(t,m, color="black", lw=3)
                filto = filter_labels[i]
                ax1.text(0.1,0.2,filto, color=band_colors[i], fontsize=19, transform=ax1.transAxes)
                ax1.fill_between(t, m+errs, m-errs, alpha=0.5, color="black")
                ax1.set_ylabel(ylab)
                ax2.set_ylabel(r"$\chi$")
                #print('   --> Skipping')
                ko =0
            else:
                ko=1
                            
    for ax in axs.flat:
        ax.label_outer()    

    
    plt.savefig("pyroa_lightcurve.png", dpi=300)

def FluxFlux(objName, filters, delay_ref, gal_ref,wavelengths,
            lc_file="Lightcurve_models.obj",
            samples_file='samples_flat.obj',
            xt_file='X_t.obj',
            outputdir = './', datadir='./',
            burnin=0, band_colors = None,
            input_units='mJy',output_units='mJy',
            redshift=0.0, ebv=0.0,
            limits=None, ylab = None,
            savefig=True, figname=None,
            model=None):
    
    def unred(wave, flux, ebv, R_V=3.1, LMC2=False, AVGLMC=False):

        x = 10000./ wave # Convert to inverse microns
        curve = x*0.

        # Set some standard values:
        x0 = 4.596
        gamma =  0.99
        c3 =  3.23
        c4 =  0.41
        c2 = -0.824 + 4.717/R_V
        c1 =  2.030 - 3.007*c2

        if LMC2:
            x0	=  4.626
            gamma =  1.05
            c4   =  0.42
            c3	=  1.92
            c2	= 1.31
            c1	=  -2.16
        elif AVGLMC:
            x0 = 4.596
            gamma = 0.91
            c4   =  0.64
            c3	=  2.73
            c2	= 1.11
            c1	=  -1.28

        # Compute UV portion of A(lambda)/E(B-V) curve using FM fitting function and
        # R-dependent coefficients
        xcutuv = np.array([10000.0/2700.0])
        xspluv = 10000.0/np.array([2700.0,2600.0])

        iuv = np.where(x >= xcutuv)[0]
        N_UV = len(iuv)
        iopir = np.where(x < xcutuv)[0]
        Nopir = len(iopir)
        if (N_UV > 0): xuv = np.concatenate((xspluv,x[iuv]))
        else:  xuv = xspluv

        yuv = c1  + c2*xuv
        yuv = yuv + c3*xuv**2/((xuv**2-x0**2)**2 +(xuv*gamma)**2)
        yuv = yuv + c4*(0.5392*(np.maximum(xuv,5.9)-5.9)**2+0.05644*(np.maximum(xuv,5.9)-5.9)**3)
        yuv = yuv + R_V
        yspluv  = yuv[0:2]  # save spline points

        if (N_UV > 0): curve[iuv] = yuv[2::] # remove spline points

        # Compute optical portion of A(lambda)/E(B-V) curve
        # using cubic spline anchored in UV, optical, and IR
        xsplopir = np.concatenate(([0],10000.0/np.array([26500.0,12200.0,6000.0,5470.0,4670.0,4110.0])))
        ysplir   = np.array([0.0,0.26469,0.82925])*R_V/3.1
        ysplop   = np.array((np.polyval([-4.22809e-01, 1.00270, 2.13572e-04][::-1],R_V ),
                np.polyval([-5.13540e-02, 1.00216, -7.35778e-05][::-1],R_V ),
                np.polyval([ 7.00127e-01, 1.00184, -3.32598e-05][::-1],R_V ),
                np.polyval([ 1.19456, 1.01707, -5.46959e-03, 7.97809e-04, -4.45636e-05][::-1],R_V ) ))
        ysplopir = np.concatenate((ysplir,ysplop))

        if (Nopir > 0):
          tck = interpolate.splrep(np.concatenate((xsplopir,xspluv)),np.concatenate((ysplopir,yspluv)),s=0)
          curve[iopir] = interpolate.splev(x[iopir], tck)

        #Now apply extinction correction to input flux vector
        curve *= ebv

        return flux * 10.**(0.4*curve)

    plt.rcParams.update({
        "font.family": "Sans",  
        "font.serif": ["DejaVu"],
    "figure.figsize":[40,30],
    "font.size": 19})  

    if outputdir[-1] != '/': outputdir += '/'
    if ylab ==None: ylab = r"F$_{\nu}$"+" (mJy)"
    #if filter_labels == None: filter_labels = filters

    if input_units == 'mJy': funits = 1*u.mJy
    if input_units == 'Jy': funits = 1*u.Jy
    if input_units == 'fnu': funits = 1*u.erg/u.s/(u.cm**2)/u.Hz
    if input_units == 'flam': funits = 1*u.erg/u.s/(u.cm**2)/u.Angstrom
    if output_units == 'mJy': 
        ylab = r"F$_{\nu}$"+" (mJy)"
    if output_units == 'Jy': 
        ylab = r"F$_{\nu}$"+" (mJy)"
    if output_units == 'fnu': 
        ylab = r"F$_{\nu}$"+r" / erg s$^{-1}$ cm$^{-2}$ ${\rm Hz}^{-1}$"
    if output_units == 'flam': 
        ylab = r"F$_{\lambda}$"+r" / $\times10^{-15}$ erg s$^{-1}$ cm$^{-2}$ ${\rm \AA}^{-1}$"

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    file = open(outputdir+samples_file,'rb')
    samples_flat = pickle.load(file)
    samples_flat = samples_flat[burnin:,:]
    file = open(outputdir+lc_file,'rb')
    models = pickle.load(file)
    
    file = open(outputdir+xt_file,'rb')
    norm_lc = pickle.load(file)
    wave = np.array(wavelengths)

    #Split samples into chunks, 4 per lightcurve i.e A, B, tau, sig
    chunk_size=4
    transpose_samples = np.transpose(samples_flat)
    #Insert zero where tau_0 would be 
    transpose_samples = np.insert(transpose_samples, [ss*4+2], np.array([0.0]*len(transpose_samples[1])), axis=0)
    samples_chunks = [transpose_samples[i:i + chunk_size] for i in range(0, len(transpose_samples), chunk_size)]
    

    gal_spectrum,gal_spectrum_err,fnu_f,fnu_b,slope,slope_err = [],[],[],[],[],[]
    fnu_f_err,fnu_b_err = [], []
    
    fig = plt.figure(figsize=(10,7))
    xx = np.linspace(-15,5,300)
    max_flux = 0.0
    
    kk = 0
    fac_flux = np.ones(len(wavelengths))
    for i in range(len(filters)):
        if ((filters[i] != delay_ref) ):
            
            file = datadir + objName+"_" + str(filters[i]) + ".dat"
            data = np.loadtxt(file)
            snu_mcmc = samples_chunks[i][0]
            cnu_mcmc = samples_chunks[i][1]            
            sig = np.percentile(samples_chunks[i][3], 50)

            mc_pl = np.zeros((200,xx.size))

            for lo in range(200):
                jj = int(np.random.uniform(0,snu_mcmc.size))
                mc_pl[lo] = cnu_mcmc[jj] + xx * snu_mcmc[jj]
            
            if filters[i] == gal_ref: 
                x_gal_mcmc = -cnu_mcmc/snu_mcmc
                x_gal = np.median(x_gal_mcmc)
                x_gal_error = np.std(-cnu_mcmc/snu_mcmc)
                
            gal_spectrum_mcmc = np.median(cnu_mcmc) +  (x_gal_mcmc+x_gal_mcmc.std()) * np.median(snu_mcmc)
            
            gal_spectrum.append(gal_spectrum_mcmc.mean())
            gal_spectrum_err.append(gal_spectrum_mcmc.std())
            
            fnu_f_mcmc = snu_mcmc * (np.min(norm_lc[1]) - x_gal_mcmc)
            fnu_b_mcmc = snu_mcmc * (np.max(norm_lc[1]) - x_gal_mcmc)
    
            fnu_f.append(fnu_f_mcmc.mean())
            fnu_f_err.append(fnu_f_mcmc.std())

            fnu_b.append(fnu_b_mcmc.mean())
            fnu_b_err.append(fnu_b_mcmc.std())

            slope.append(np.median(snu_mcmc))
            slope_err.append(np.std(snu_mcmc))

            lin_fit = np.median(snu_mcmc) * xx + np.median(cnu_mcmc)
            
            
            if wavelengths != None:       

                         
                if (input_units != 'flam') and (output_units !='flam'):
                    wave = wavelengths[i+kk] * u.Angstrom
                    dd = funits
                    #print(input_units,output_units)
                    if output_units != 'fnu':
                        fac_flux[i+kk] = dd.cgs.to(output_units).value
                    else:
                        fac_flux[i+kk] = dd.cgs.to('erg s^-1 cm^-2 Hz^-1').value

                if (input_units != 'flam') and (output_units =='flam'):
                    wave = wavelengths[i+kk] * u.Angstrom
                    dd = funits/(wave**2)*ct.c

                    fac_flux[i+kk] = dd.cgs.to('erg s^-1 cm^-2 Angstrom^-1').value/1e-15

                if (input_units == 'flam') and (output_units !='flam'):
                    wave = wavelengths[i+kk] * u.Angstrom
                    dd = funits/ct.c*(wave**2)
                    if output_units != 'fnu':
                        fac_flux[i+kk] = dd.cgs.to(output_units).value
                    else:
                        fac_flux[i+kk] = dd.cgs.to('erg s^-1 cm^-2 Hz^-1').value
                        
                        
            plt.fill_between(xx,(mc_pl.mean(axis=0)+mc_pl.std(axis=0))*fac_flux[i+kk],
                        (mc_pl.mean(axis=0)-mc_pl.std(axis=0))*fac_flux[i+kk],
                        color=band_colors[i],
                        alpha=0.3)
            interp_xt = np.interp(data[:,0],norm_lc[0],norm_lc[1])
            plt.errorbar(interp_xt,data[:,1]*fac_flux[i+kk],
                        yerr=np.sqrt(data[:,2]**2+sig**2)*fac_flux[i+kk],
                        color=band_colors[i],
                        ls='None',alpha=0.8)
            plt.plot(xx,lin_fit*fac_flux[i+kk],color=band_colors[i],lw=3)
            max_flux = np.max([max_flux,np.max(data[:,1]*fac_flux[i+kk])])

        else:
            kk = -1
    fnu_f = np.array(fnu_f)
    fnu_f_err = np.array(fnu_f_err)
    fnu_b = np.array(fnu_b)
    fnu_b_err = np.array(fnu_b_err)
    slope = np.array(slope)
    slope_err = np.array(slope_err)
    gal_spectrum = np.array(gal_spectrum)
    gal_spectrum_err = np.array(gal_spectrum_err)
    
    plt.axvline(x=np.median(x_gal_mcmc+x_gal_mcmc.std()),color='r',
                linestyle='-.',label=r'$X_{\rm G}$')
    plt.axvline(x=np.min(norm_lc[1]),color='k',
                linestyle='--',label=r'$X_{\rm F}$')
    plt.axvline(x=np.max(norm_lc[1]),color='grey',
                linestyle='--',label=r'$X_{\rm B}$')

    lg = plt.legend(ncol=4)
    plt.xlim(x_gal-1,3)
    #print()
    plt.ylim(-0.04*fac_flux[-1],max_flux*1.2)
    if limits != None: plt.ylim(-0.04,limits[1])
    
    plt.xlabel(r'$X_0 (t)$, Normalised driving light curve flux')
    plt.ylabel(ylab)
    plt.tight_layout()
    if savefig:
        plt.savefig("pyroafluxflux.png", dpi=300)
    if wavelengths != None:
        wave = np.array(wavelengths)
        fig = plt.figure(figsize=(10,7))
        ax = fig.add_subplot(111)
        xxx = np.arange(2000,9300)

        # AGN variability range
        plt.fill_between(wave/(1+redshift),(np.array(unred(wave,fnu_b,ebv)))*fac_flux,
                        (np.array(unred(wave,fnu_f,ebv)))*fac_flux
                         ,color='k',alpha=0.1,label='AGN variability')
        ### F_bright - F_faint
        plt.errorbar(wave/(1+redshift),(np.array(unred(wave,fnu_b,ebv)) - \
                                    np.array(unred(wave,fnu_f,ebv)))*fac_flux,
                 yerr=np.sqrt((np.array(fnu_f_err))**2 + (np.array(fnu_b_err))**2)*fac_flux,
                 marker='.',linestyle='-',color='k',
                 label=r'AGN high-low',ms=15)
        
        ### average spectrum
        plt.errorbar(wave/(1+redshift),(np.array(unred(wave,fnu_b,ebv)) + np.array(unred(wave,fnu_f,ebv)))*fac_flux/2,
                 yerr=np.sqrt((np.array(fnu_f_err))**2 + (np.array(fnu_b_err))**2)*fac_flux/2,
                 marker='.',linestyle='-',color='b',
                 label=r'AGN average',ms=15)

        ### AGN RMS
        plt.errorbar(wave/(1+redshift),np.array(unred(wave,slope,ebv))*fac_flux,
             yerr=0,marker='o',linestyle='--',color='grey',label='AGN RMS')
        
        ### Galaxy spectrum
        plt.errorbar(wave/(1+redshift),unred(wave,gal_spectrum,ebv)*fac_flux,
                 yerr=gal_spectrum_err*fac_flux,
                 marker='s',color='r',label='Galaxy contribution',linestyle='-.')
        
        a=12
        x = np.linspace(1000,10000)
        y = a*x**-(1/3)
        ax.plot(x, y, label=r"$\nu^{1/3}$", color="green")
    
        #print(fac_flux)
        plt.xscale('log')
        plt.yscale('log')
        plt.xlim(np.min(wave/(1+redshift))-100,np.max(wave/(1+redshift))+100)
        #print(np.min(np.array(unred(wave,slope,ebv)))*0.7,max_flux*1.2)
        plt.ylim(np.min(np.array(unred(wave,slope,ebv)))*0.7*fac_flux[-1],max_flux*1.2)
        if limits != None: plt.ylim(limits[0],limits[1])
        lg = plt.legend(ncol=2, loc='lower right')
        if redshift > 0:
            plt.xlabel(r'Rest Wavelength ($\mathrm{\AA}$)')
        else:
            plt.xlabel(r'Observed Wavelength / $\mathrm{\AA}$')
        plt.ylabel(ylab)

        ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2f'))
        ax.xaxis.set_major_formatter(mtick.FormatStrFormatter('%.0f'))
        ax.xaxis.set_minor_formatter(mtick.FormatStrFormatter('%.0f'))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(2000))
        plt.tight_layout()
        
        if savefig:
            plt.savefig("pyroased.png", dpi=300)
    else:
        print(' [PyROA] No wavelength list. Skipping SED plot.')
        # Create output file from flux-flux analysis
    #print(wave)
    d = {'wave': wave,
         'mean': np.array((np.array(unred(wave,fnu_b,ebv)) + np.array(unred(wave,fnu_f,ebv)))*fac_flux/2),
         'mean_err': np.array(np.sqrt((np.array(fnu_f_err))**2 + (np.array(fnu_b_err))**2)*fac_flux/2),
         'faint':np.array(unred(wave,fnu_f,ebv))*fac_flux,
         'faint_err':np.sqrt((np.array(fnu_f_err)))}
    df = pd.DataFrame(data=d)
    df.to_csv(objName+'_MEANSPEC.csv',index=False)

In [None]:
Lightcurves(obj_name,filters,delay_ref,
                  datadir=datadir, outputdir=outputdir,
                  burnin=burnin, band_colors=band_colors,
                  grid=True, show_delay_ref=False, time=swiftxrt)

In [None]:
FluxFlux(obj_name,filters,delay_ref,gal_ref,waves,
                datadir=datadir,outputdir=outputdir,
                burnin=burnin,
                band_colors=band_colors,
                ebv=0.032,redshift=redshift,
                limits=[.1,12])

## Disc-only delay spectrum

In [None]:
#Eddington luminosity function
def lumin(M = (10**8)*1.989e30):
    G = 6.6743e-11 #Nm^2/kg^2
    # M = (10**8.05)*1.989e30 #kg
    c = 299792458 #m/s
    mp = 1.67262192*1e-27
    sigmaTh = 6.6524587321e-29 #m^2
    Ledd = 4*np.pi*G*M*c*mp/sigmaTh
    return Ledd

In [None]:
FILTERS = ["UVW2", "gp1", "gp", "V", "rp", "ip", "zs"]
filters = FILTERS
waves = [1894, 4770, 5510, 6231, 7625, 9134]
redshift = 0.0364
band_colors=['#0652DD','#1289A7','#006266','#A3CB38','orange','#EE5A24','brown']
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2"
burnin=150000

# following functions are adapted from Utils.py accessible at https://github.com/Alymantara/PyROA
# Reference: Donnan et al. 2021 (https://ui.adsabs.harvard.edu/abs/2021arXiv210712318D/abstract)
def DiscLagSpectrum(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdirs = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, resume=False):

    fig, ax = plt.subplots(figsize=(10,8))
    outputdir=outputdirs
    color='k'
    marker='o'
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    #print(ss)
    labels = []
    for i in range(len(filters)):
        for j in ["A", "B",r"$\tau$", r"$\sigma$"]:
            labels.append(j+r'$_{'+filters[i]+r'}$')
    labels.append(r'$\Delta$')
    all_labels = labels.copy()
    del labels[ss*4+2]

    # To get ONLY lags
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
    # Get the 
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50

    # ax = fig.add_subplot(111)

    plt.axhline(y=0,ls='--',alpha=0.5,color='k')

    if band_colors == None: band_colors = 'k'*7

    mm = 0
    for i in range(lag.size):        
        plt.errorbar(wavelengths[i]/(1+redshift),lag[i]/(1+redshift),
                    yerr=lag_m[i],marker=marker,
                    color=color, ms=15)

    if redshift > 0:
        plt.xlabel(r'Rest Wavelength / $\mathrm{\AA}$', fontsize=14)
        plt.ylabel(r'$\tau_{\rm rest}$ / day', fontsize=14)
    else:
        plt.xlabel(r'Observed Wavelength / $\mathrm{\AA}$')
        plt.ylabel(r'$\tau$ / day')
        
    delta = round(np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]]))

    ax.lines[-1].set_label(fr'lags, $\Delta$={delta}')
        
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)
    
    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    
    ax.axvline(lam0, color='k', linestyle="--", alpha=0.5, zorder=0)
    
    def alpha(theta, FluxWeight=True, lam0=lam0):
        # print(lam0)
        logmdotedd = theta[0]
        c = 299792458 #m/s
        k = 1.3806452e-23 #J/K
        lam0 = lam0 * 1e-10 #m
        h = 6.62607015e-34 #J/Hz
        G = 6.6743e-11 #Nm^2/kg^2
        M = (10**8)*1.989e30 #kg
        sigma = 5.670374419e-8 #W/m^2K^4
        eta = 0.1
        kappa = 1
        mdotedd = 10**logmdotedd #LBol/LEdd

        if FluxWeight == True:
            X = 2.49
        else:
            X=4.97

        Ledd = lumin(M)
        alpha = (1/c)*(X*k*lam0/(h*c) )**(4/3) * ( (G*M/(8*np.pi*sigma)) * (Ledd/(eta*c**2)) * (3+kappa)*mdotedd )**(1/3)
        return alpha/(60*60*24)
    
    def timelag(theta, lam=fitwave, FluxWeight=True, lam0=lam0):
        # A, x, beta, y0, logmdotedd = theta
        logmdotedd = theta[0]
        a = alpha(theta, FluxWeight=FluxWeight)
        return a*((lam/lam0)**(4/3) - 1)
    
    def model(params, lam=fitwave, lam0=lam0):
        tau0=params[0]
        return tau0*((lam/lam0)**(4/3) - 1)
    
    def model_tau(theta, params, lam=fitwave, lam0=lam0):
        tau0, syserr=theta
        beta, y0=params
        return tau0*((lam/lam0)**(beta) - y0)
    
    def model_X(theta, params, lam=fitwave, lam0=lam0):
        logmdotedd, syserr=theta
        M, beta, y0, FluxWeight=params
        a = alpha(theta, FluxWeight=FluxWeight)
        return a*((lam/lam0)**beta - y0)
    
    def residual_tau(params, x, y, yerr):
        tau0 = params['tau0'].value
        params=[tau0]
        return (y - model(params, lam=x))/yerr
    
    def residual_X(params, x, y, yerr, FluxWeight):
        # M = params['M'].value
        logmdotedd = params['logmdotedd'].value
        params = np.array([logmdotedd])
        return (y - timelag(params, lam=x, FluxWeight=FluxWeight))/yerr
    
    def lnlike_tau(theta, params, x, y, yerr):
        tau0, syserr = theta
        beta, y0 = params
        # inputs = [tau0, beta, y0, syserr]
        return -1/2 * np.sum(((y - model_tau(theta, params, x))/yerr)**2 +np.log(yerr**2+syserr**2) )

    def lnprior_tau(theta):
        tau0, syserr = theta
        if -5.0 < tau0 < 5.0 and -0 < syserr < 5.0:
            return 0.0
        return -np.inf

    def lnprob_tau(theta, params, x, y, yerr):
        lp = lnprior_tau(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike_tau(theta, params, x, y, yerr)
    
    def lnlike_X(theta, params, x, y, yerr):
        logmdotedd, syserr = theta
        M, beta, y0, FluxWeight = params
        # inputs = [tau0, beta, y0, syserr]
        return -1/2 * np.sum(((y - model_X(theta, params, x))/yerr)**2 +np.log(yerr**2+syserr**2) )

    def lnprior_X(theta):
        logmdotedd, syserr = theta
        if np.log10(2e-2) < logmdotedd < np.log10(18e-2) and 0 < syserr < 5.0:
            return 0.0
        return -np.inf

    def lnprob_X(theta, params, x, y, yerr):
        lp = lnprior_X(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike_X(theta, params, x, y, yerr)
    
    def fitter2(data, residual, fit_params, model, params, functions, labels, file, nwalkers=500, niter=10000):
        lnlike, lnprior, lnprob = functions
        
        result = lmft.minimize(residual, fit_params, args=data)
        
        keys = [result.params[key].value for key in result.params]
        keys.append(1) #for syserr
        
        initial = np.array(keys)
        ndim = len(initial)
        p0 = [np.array(initial) + 1e-7 * np.random.randn(ndim) for i in range(nwalkers)]
        
        if len(data) == 4:
            args = (params, *data[:-1])
        else:
            args = (params, *data)
            
        filename = f"{file}"
        backend = emcee.backends.HDFBackend(filename)
        # print(backend.iteration)
        # backend.reset(nwalkers, ndim)
            
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, args=args, threads=12, backend=backend)
        if resume == True:
            p0 = None
            print("true")
            pos, prob, state = sampler.run_mcmc(p0, niter, progress=True)
        elif resume == False:
            print("false")
            p0, _, _ = sampler.run_mcmc(p0, 100, progress=True)
            sampler.reset()
            pos, prob, state = sampler.run_mcmc(p0, niter, progress=True)
        
        samples = sampler.flatchain
        
        # chains(sampler, labels, ndim)
        
        print(sampler.get_autocorr_time())
        
        flat_samples = sampler.get_chain(discard=3*round(np.mean(sampler.get_autocorr_time())), thin=15, flat=True)
        
        # emcee_corner = corner.corner(flat_samples, labels=labels, show_titles=True)
        
        theta_max = samples[np.argmax(sampler.flatlnprobability)]
        
        hp = []
        stdminus = []
        stdplus = []
        for i in range(ndim):
            mcmc = np.percentile(flat_samples[:,i], [16,50,84])
            q = np.diff(mcmc)
            hp.append(mcmc[1])
            stdminus.append(q[0])
            stdplus.append(q[1])
        
        sig = np.sqrt(fiterr**2 + hp[-1])
    
        def sample_walkers(model, params, nsamples=10000, flattened_chain=flat_samples):
            models=[]
            draw = np.floor(np.random.uniform(0, len(flat_samples), size=nsamples)).astype(int)
            thetas = flattened_chain[draw]
            for i in thetas:
                mod = model(i, params)
                models.append(mod)
            spread = np.std(models, axis=0)
            med_model = np.median(models, axis=0)
            return med_model, spread
        
        med_model, spread = sample_walkers(model, params)
        # print(len(med_model), len(spread))
        
        return theta_max, stdminus, stdplus, med_model, spread, sampler, sampler.get_chain(), flat_samples
        
    
    fit_params_tau = lmft.create_params(tau0={'value':3, 'min':0, 'max':10,'vary':True})
    fit_params_X = lmft.create_params(logmdotedd={'value':-1.17, 'min':np.log10(2e-2), 'max':np.log10(18e-2), 'vary':True})
    
    data = (fitwave, fitlag, fiterr)
    labels = [r'tau_0', 'syserr']
    params = [4/3, 1]
    functions = [lnlike_tau, lnprior_tau, lnprob_tau]
    file = "bestdisc2.h5"
    name = "bestdisc2"
    resume=False
    
    keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples = fitter2(data, residual_tau, fit_params_tau, model_tau, params, functions, labels, file)
    taumodel = [keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples]
    
    
    df = pd.DataFrame({'keys':keys, 'stdplus':stdplus, 'stdminus':stdminus})
    df.to_pickle(f"{name}_model2.pkl")
    
    df = pd.DataFrame({'med_model':med_model, 'spread':spread})
    df.to_pickle(f"{name}_uncertainty2.pkl")
    
    values = ['tau0']
    
    df = pd.DataFrame({'syserr_step':samples[:, :, -1][0], 'syserr': samples[:, :, -1][1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}_step'] = samples[:, :, i][0]
        df.loc[:, fr'{values[i]}'] = samples[:, :, i][1]
    df.to_pickle(f"{name}_samples2.pkl")
    
    df = pd.DataFrame({'syserr':flat_samples[:,-1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}'] = flat_samples[:,i]
    df.to_pickle(f"{name}_flat_samples2.pkl")
    
    best_fit_model = model(keys)
    
    df = pd.DataFrame({'best_fit':best_fit_model})
    df.to_pickle(f"{name}_components2.pkl")
    
    
    data = (fitwave, fitlag, fiterr, True)
    labels = [r'\log{\dot{M}_{\rm Edd}}', 'syserr']
    params = [10**8*1.989e30, 4/3, 1, True]
    functions = [lnlike_X, lnprior_X, lnprob_X]
    file = "X2492.h5"
    name = "X2492"
    resume=False
    
    keys2, stdminus2, stdplus2, med_model2, spread2, sampler2, samples2, flat_samples2 = fitter2(data, residual_X, fit_params_X, model_X, params, functions, labels, file)
    X2 = [keys2, stdminus2, stdplus2, med_model2, spread2, sampler2, samples2, flat_samples2]
    
    df = pd.DataFrame({'keys':keys2, 'stdplus':stdplus2, 'stdminus':stdminus2})
    df.to_pickle(f"{name}_model2.pkl")
    
    df = pd.DataFrame({'med_model':med_model2, 'spread':spread2})
    df.to_pickle(f"{name}_uncertainty2.pkl")
    
    values = ['logmdotedd']
    
    df = pd.DataFrame({'syserr_step':samples2[:, :, -1][0], 'syserr': samples2[:, :, -1][1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}_step'] = samples2[:, :, i][0]
        df.loc[:, fr'{values[i]}'] = samples2[:, :, i][1]
    df.to_pickle(f"{name}_samples2.pkl")
    
    df = pd.DataFrame({'syserr':flat_samples2[:,-1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}'] = flat_samples2[:,i]
    df.to_pickle(f"{name}_flat_samples2.pkl")
    
    best_fit_model2 = timelag(keys2, FluxWeight=True)
    
    df = pd.DataFrame({'best_fit':best_fit_model2})
    df.to_pickle(f"{name}_components2.pkl")
    
    data = (fitwave, fitlag, fiterr, False)
    params = [10**8*1.989e30, 4/3, 1, False]
    # file = "X249.h5"
    file="X4972.h5"
    name = "X4972"
    resume=False
    
    keys4, stdminus4, stdplus4, med_model4, spread4, sampler4, samples4, flat_samples4 = fitter2(data, residual_X, fit_params_X, model_X, params, functions, labels, file)
    X4 = [keys4, stdminus4, stdplus4, med_model4, spread4, sampler4, samples4, flat_samples4]
    
    df = pd.DataFrame({'keys':keys4, 'stdplus':stdplus4, 'stdminus':stdminus4})
    df.to_pickle(f"{name}_model2.pkl")
    
    df = pd.DataFrame({'med_model':med_model4, 'spread':spread4})
    df.to_pickle(f"{name}_uncertainty2.pkl")
    
    values = ['logmdotedd']
    
    df = pd.DataFrame({'syserr_step':samples4[:, :, -1][0], 'syserr': samples4[:, :, -1][1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}_step'] = samples4[:, :, i][0]
        df.loc[:, fr'{values[i]}'] = samples4[:, :, i][1]
    df.to_pickle(f"{name}_samples2.pkl")
    
    df = pd.DataFrame({'syserr':flat_samples4[:,-1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}'] = flat_samples4[:,i]
    df.to_pickle(f"{name}_flat_samples2.pkl")
    
    best_fit_model4 = timelag(keys4, FluxWeight=False)
    
    df = pd.DataFrame({'best_fit':best_fit_model4})
    df.to_pickle(f"{name}_components2.pkl")
    
    ax.plot(fitwave, timelag(keys2, FluxWeight=True), label="X=2.49 disc expectation", color='blue')
    ax.plot(fitwave, timelag(keys4, FluxWeight=False), label="X=4.97 disc expectation", color='red')
    ax.plot(fitwave, model(keys), label=fr"Highest likelihood disc model, $\tau_0$={round(keys[0],2)}", color='green')
    
    ax.fill_between(fitwave, med_model4-spread4, med_model4+spread4, color='red', alpha=0.2, label=fr"1 $\sigma$ spread")
    ax.fill_between(fitwave, med_model2-spread2, med_model2+spread2, color='blue', alpha=0.1, label=fr"1 $\sigma$ spread")
    ax.fill_between(fitwave, med_model-spread, med_model+spread, color='green', alpha=0.2, label=fr"1 $\sigma$ spread")
    
    ax.set_xlim(min(fitwave)-200, max(fitwave)+200)
    ax.set_ylim(min(fitlag)*1.1, max(fitlag)*1.1)

    lines, labels = ax.get_legend_handles_labels()
    
    plt.legend([lines[0], lines[1], lines[2], lines[3], (lines[4], lines[5], lines[6])], 
                 labels, handler_map={tuple: HandlerTuple(ndivide=None)})
    plt.savefig("delspecdisc2.png", dpi=300)
    plt.show()
    return taumodel, X2, X4

models = DiscLagSpectrum(filters,delay_ref,outputdirs=outputdir,
                burnin=burnin,
                band_colors=band_colors,
                wavelengths=waves,redshift=redshift,
                dataframe = DF_FCCF, resume=False)

In [None]:
FILTERS = ["UVW2", "gp1", "gp", "V", "rp", "ip", "zs"]
filters = FILTERS
waves = [1894, 4770, 5510, 6231, 7625, 9134]
redshift = 0.0364
band_colors=['#0652DD','#1289A7','#006266','#A3CB38','orange','#EE5A24','brown']
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2"
burnin=150000

# following functions are adapted from Utils.py accessible at https://github.com/Alymantara/PyROA
# Reference: Donnan et al. 2021 (https://ui.adsabs.harvard.edu/abs/2021arXiv210712318D/abstract)
def DiscLagSpectrumPlot(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdirs = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, resume=False, add=""):

    fig, ax = plt.subplots(figsize=(15,10))
    outputdir=outputdirs
    color='k'
    marker='o'
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    #print(ss)
    labels = []
    for i in range(len(filters)):
        for j in ["A", "B",r"$\tau$", r"$\sigma$"]:
            labels.append(j+r'$_{'+filters[i]+r'}$')
    labels.append(r'$\Delta$')
    all_labels = labels.copy()
    del labels[ss*4+2]

    # To get ONLY lags
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
    # Get the 
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50

    # ax = fig.add_subplot(111)

    plt.axhline(y=0,ls='--',alpha=0.5,color='k')

    if band_colors == None: band_colors = 'k'*7

    mm = 0
    for i in range(lag.size):        
        plt.errorbar(wavelengths[i]/(1+redshift),lag[i]/(1+redshift),
                    yerr=lag_m[i],marker=marker,
                    color=color, ms=15)

    if redshift > 0:
        plt.xlabel(r'Rest Wavelength / $\mathrm{\AA}$')
        plt.ylabel(r'$\tau_{\rm rest}$ / day')
    else:
        plt.xlabel(r'Observed Wavelength / $\mathrm{\AA}$')
        plt.ylabel(r'$\tau$ / day')
        
    delta = round(np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]]))

    ax.lines[-1].set_label(fr'lags, $\Delta$={delta}')
        
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)
    
    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    
    ax.axvline(lam0, color='k', linestyle="--", alpha=0.5, zorder=0)
    
    def alpha(theta, FluxWeight=True, lam0=lam0):
        # print(lam0)
        logmdotedd = theta[0]
        c = 299792458 #m/s
        k = 1.3806452e-23 #J/K
        lam0 = lam0 * 1e-10 #m
        h = 6.62607015e-34 #J/Hz
        G = 6.6743e-11 #Nm^2/kg^2
        M = (10**8)*1.989e30 #kg
        sigma = 5.670374419e-8 #W/m^2K^4
        eta = 0.1
        kappa = 1
        mdotedd = 10**logmdotedd #LBol/LEdd

        if FluxWeight == True:
            X = 2.49
        else:
            X=4.97

        Ledd = lumin(M)
        alpha = (1/c)*(X*k*lam0/(h*c) )**(4/3) * ( (G*M/(8*np.pi*sigma)) * (Ledd/(eta*c**2)) * (3+kappa)*mdotedd )**(1/3)
        return alpha/(60*60*24)
    
    def timelag(theta, lam=fitwave, FluxWeight=True, lam0=lam0):
        # A, x, beta, y0, logmdotedd = theta
        logmdotedd = theta[0]
        a = alpha(theta, FluxWeight=FluxWeight)
        return a*((lam/lam0)**(4/3) - 1)
    
    def model(params, lam=fitwave, lam0=lam0):
        tau0=params[0]
        return tau0*((lam/lam0)**(4/3) - 1)
    
    def model_tau(theta, params, lam=fitwave, lam0=lam0):
        tau0, syserr=theta
        beta, y0=params
        return tau0*((lam/lam0)**(beta) - y0)
    
    def model_X(theta, params, lam=fitwave, lam0=lam0):
        logmdotedd, syserr=theta
        M, beta, y0, FluxWeight=params
        a = alpha(theta, FluxWeight=FluxWeight)
        return a*((lam/lam0)**beta - y0)
    
    def residual_tau(params, x, y, yerr):
        tau0 = params['tau0'].value
        params=[tau0]
        return (y - model(params, lam=x))/yerr
    
    def residual_X(params, x, y, yerr, FluxWeight):
        # M = params['M'].value
        logmdotedd = params['logmdotedd'].value
        params = np.array([logmdotedd])
        return (y - timelag(params, lam=x, FluxWeight=FluxWeight))/yerr
    
    def lnlike_tau(theta, params, x, y, yerr):
        tau0, syserr = theta
        beta, y0 = params
        # inputs = [tau0, beta, y0, syserr]
        return -1/2 * np.sum(((y - model_tau(theta, params, x))/yerr)**2 +np.log(yerr**2+syserr**2) )

    def lnprior_tau(theta):
        tau0, syserr = theta
        if -5.0 < tau0 < 5.0 and -0 < syserr < 5.0:
            return 0.0
        return -np.inf

    def lnprob_tau(theta, params, x, y, yerr):
        lp = lnprior_tau(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike_tau(theta, params, x, y, yerr)
    
    def lnlike_X(theta, params, x, y, yerr):
        logmdotedd, syserr = theta
        M, beta, y0, FluxWeight = params
        # inputs = [tau0, beta, y0, syserr]
        return -1/2 * np.sum(((y - model_X(theta, params, x))/yerr)**2 +np.log(yerr**2+syserr**2) )

    def lnprior_X(theta):
        logmdotedd, syserr = theta
        if np.log10(2e-2) < logmdotedd < np.log10(18e-2) and 0 < syserr < 5.0:
            return 0.0
        return -np.inf

    def lnprob_X(theta, params, x, y, yerr):
        lp = lnprior_X(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike_X(theta, params, x, y, yerr)
    
    def fitter2(data, residual, fit_params, model, params, functions, labels, file, nwalkers=500, niter=10000):
        lnlike, lnprior, lnprob = functions
        
        result = lmft.minimize(residual, fit_params, args=data)
        
        keys = [result.params[key].value for key in result.params]
        keys.append(1) #for syserr
        
        initial = np.array(keys)
        ndim = len(initial)
        p0 = [np.array(initial) + 1e-7 * np.random.randn(ndim) for i in range(nwalkers)]
        
        if len(data) == 4:
            args = (params, *data[:-1])
        else:
            args = (params, *data)
            
        filename = f"{file}"
        backend = emcee.backends.HDFBackend(filename)
        # print(backend.iteration)
        # backend.reset(nwalkers, ndim)
            
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, args=args, threads=12, backend=backend)
        if resume == True:
            p0 = None
            print("true")
            pos, prob, state = sampler.run_mcmc(p0, niter, progress=True)
        elif resume == False:
            print("false")
            p0, _, _ = sampler.run_mcmc(p0, 100, progress=True)
            sampler.reset()
            pos, prob, state = sampler.run_mcmc(p0, niter, progress=True)
        
        samples = sampler.flatchain
        
        # chains(sampler, labels, ndim)
        
        print(sampler.get_autocorr_time())
        
        flat_samples = sampler.get_chain(discard=3*round(np.mean(sampler.get_autocorr_time())), thin=15, flat=True)
        
        # emcee_corner = corner.corner(flat_samples, labels=labels, show_titles=True)
        
        theta_max = samples[np.argmax(sampler.flatlnprobability)]
        
        hp = []
        stdminus = []
        stdplus = []
        for i in range(ndim):
            mcmc = np.percentile(flat_samples[:,i], [16,50,84])
            q = np.diff(mcmc)
            hp.append(mcmc[1])
            stdminus.append(q[0])
            stdplus.append(q[1])
        
        sig = np.sqrt(fiterr**2 + hp[-1])
    
        def sample_walkers(model, params, nsamples=10000, flattened_chain=flat_samples):
            models=[]
            draw = np.floor(np.random.uniform(0, len(flat_samples), size=nsamples)).astype(int)
            thetas = flattened_chain[draw]
            for i in thetas:
                mod = model(i, params)
                models.append(mod)
            spread = np.std(models, axis=0)
            med_model = np.median(models, axis=0)
            return med_model, spread
        
        med_model, spread = sample_walkers(model, params)
        # print(len(med_model), len(spread))
        
        return theta_max, stdminus, stdplus, med_model, spread, sampler, sampler.get_chain(), flat_samples
        
    
    fit_params_tau = lmft.create_params(tau0={'value':3, 'min':0, 'max':10,'vary':True})
    fit_params_X = lmft.create_params(logmdotedd={'value':-1.17, 'min':np.log10(2e-2), 'max':np.log10(18e-2), 'vary':True})
    
    data = (fitwave, fitlag, fiterr)
    labels = [r'tau_0', 'syserr']
    params = [4/3, 1]
    functions = [lnlike_tau, lnprior_tau, lnprob_tau]
    file = "bestdisc2.h5"
    name = "bestdisc2"
    resume=False
    
    df = pd.read_pickle(f"{name}_model{add}.pkl")
    keys = df['keys'].values
    stdplus = df['stdplus'].values
    stdminus=df['stdminus'].values
    
    df = pd.read_pickle(f"{name}_uncertainty{add}.pkl")
    med_model = df['med_model'].values
    spread = df['spread'].values
    
    values = ['tau0', 'syserr']
    df = pd.read_pickle(f"{name}_flat_samples{add}.pkl")
    flat_samples = np.array([df[f'{v}'] for v in values]).T
    
    df = pd.read_pickle(f"{name}_samples{add}.pkl")
    samples = []
    for i in range(len(values)):
        v=values[i]
        smpl = np.array(df[f'{v}_step'], df[f'{v}'])
        samples.append(smpl)
    
    df = pd.read_pickle(f"{name}_components{add}.pkl")
    best_fit_model = df['best_fit'].values
        
    
    upper = []
    lower = []
    for i in range(len(med_model)):
        upper.append(med_model[i]+spread[i])
        lower.append(med_model[i]-spread[i])
    
    # keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples = fitter2(data, residual_tau, fit_params_tau, model_tau, params, functions, labels, file)
    taumodel = [keys, stdminus, stdplus, med_model, spread, samples, flat_samples]
    
    
    data = (fitwave, fitlag, fiterr, True)
    labels = [r'\log{\dot{M}_{\rm Edd}}', 'syserr']
    params = [10**8*1.989e30, 4/3, 1, True]
    functions = [lnlike_X, lnprior_X, lnprob_X]
    file = "X2492.h5"
    name = "X2492"
    resume=False
    
    df = pd.read_pickle(f"{name}_model{add}.pkl")
    keys2 = df['keys'].values
    stdplus2 = df['stdplus'].values
    stdminus2=df['stdminus'].values
    
    df = pd.read_pickle(f"{name}_uncertainty{add}.pkl")
    med_model2 = df['med_model'].values
    spread2 = df['spread'].values
    
    values = ['logmdotedd', 'syserr']
    df = pd.read_pickle(f"{name}_flat_samples{add}.pkl")
    flat_samples2 = np.array([df[f'{v}'] for v in values]).T
    
    df = pd.read_pickle(f"{name}_samples{add}.pkl")
    samples2 = []
    for i in range(len(values)):
        v=values[i]
        smpl = np.array(df[f'{v}_step'], df[f'{v}'])
        samples2.append(smpl)
    
    df = pd.read_pickle(f"{name}_components{add}.pkl")
    best_fit_model2 = df['best_fit'].values
    
    upper2 = []
    lower2 = []
    for i in range(len(med_model2)):
        upper2.append(med_model2[i]+spread2[i])
        lower2.append(med_model2[i]-spread2[i])
    
    # keys2, stdminus2, stdplus2, med_model2, spread2, sampler2, samples2, flat_samples2 = fitter2(data, residual_X, fit_params_X, model_X, params, functions, labels, file)
    X2 = [keys2, stdminus2, stdplus2, med_model2, spread2, samples2, flat_samples2]
    
    data = (fitwave, fitlag, fiterr, False)
    params = [10**8*1.989e30, 4/3, 1, False]
    # file = "X249.h5"
    file="X4972.h5"
    name = "X4972"
    resume=False
    
    df = pd.read_pickle(f"{name}_model{add}.pkl")
    keys4 = df['keys'].values
    stdplus4 = df['stdplus'].values
    stdminus4=df['stdminus'].values
    
    df = pd.read_pickle(f"{name}_uncertainty{add}.pkl")
    med_model4 = df['med_model'].values
    spread4 = df['spread'].values
    
    values = ['logmdotedd', 'syserr']
    df = pd.read_pickle(f"{name}_flat_samples{add}.pkl")
    flat_samples4 = np.array([df[f'{v}'] for v in values]).T
    
    df = pd.read_pickle(f"{name}_samples{add}.pkl")
    samples4 = []
    for i in range(len(values)):
        v=values[i]
        smpl = np.array(df[f'{v}_step'], df[f'{v}'])
        samples4.append(smpl)
    
    df = pd.read_pickle(f"{name}_components{add}.pkl")
    best_fit_model4 = df['best_fit'].values
    
    upper4 = []
    lower4 = []
    for i in range(len(med_model4)):
        upper4.append(med_model4[i]+spread4[i])
        lower4.append(med_model4[i]-spread4[i])
    
    # keys4, stdminus4, stdplus4, med_model4, spread4, sampler4, samples4, flat_samples4 = fitter2(data, residual_X, fit_params_X, model_X, params, functions, labels, file)
    X4 = [keys4, stdminus4, stdplus4, med_model4, spread4, samples4, flat_samples4]
    
    ax.plot(fitwave, best_fit_model2, label="X=2.49 disc expectation", color='blue')
    ax.plot(fitwave, best_fit_model4, label="X=4.97 disc expectation", color='red')
    ax.plot(fitwave, best_fit_model, label=fr"Highest likelihood disc model, $\tau_0$={round(keys[0],2)}", color='green')
    
    ax.fill_between(fitwave, lower4, upper4, color='red', alpha=0.2, label=fr"1 $\sigma$ spread")
    ax.fill_between(fitwave, lower2, upper2, color='blue', alpha=0.1, label=fr"1 $\sigma$ spread")
    ax.fill_between(fitwave, lower, upper, color='green', alpha=0.2, label=fr"1 $\sigma$ spread")
    
    ax.set_xlim(min(fitwave)-200, max(fitwave)+200)
    ax.set_ylim(min(fitlag)*1.1, max(fitlag)*1.1)

    lines, labels = ax.get_legend_handles_labels()
    
    plt.legend([lines[0], lines[1], lines[2], lines[3], (lines[4], lines[5], lines[6])], 
                 labels, handler_map={tuple: HandlerTuple(ndivide=None)}, ncol=3, bbox_to_anchor=(1.25,-0.2), fontsize=25)
    plt.savefig("delspecdisc2.png", dpi=300, bbox_inches="tight")
    plt.show()
    return taumodel, X2, X4

models = DiscLagSpectrumPlot(filters,delay_ref,outputdirs=outputdir,
                burnin=burnin,
                band_colors=band_colors,
                wavelengths=waves,redshift=redshift,
                dataframe = DF_FCCF, resume=False, add=2)

In [None]:
dfs = []
dfname = ["X249", "X497", "best"]
for i in range(len(models)):
    model = models[i]
    df = pd.DataFrame([{'theta':model[0], 'stdminus':model[1], 'stdplus':model[2], 'medmod':model[3], 'spread':model[4], 'samples':model[-2], 'flat_samples':model[-1]}])
    df.to_csv(fr"{dfname[i]}.csv")
    dfs.append(df)

In [None]:
def tauCompare(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdirs = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, database=None, theta=None,flat_samples=None, params=None):
    outputdir=outputdirs
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]
    ndim = len(filters)
    
    ss = np.where(np.array(filters) == delay_ref)[0][0]
    
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
   
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50
        
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)
    
    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    
    def alpha(theta, FluxWeight=True, lam0=lam0):
        logmdotedd = theta[0]
        c = 299792458 #m/s
        k = 1.3806452e-23 #J/K
        lam0 = lam0 * 1e-10 #m
        h = 6.62607015e-34 #J/Hz
        G = 6.6743e-11 #Nm^2/kg^2
        M = (10**8)*1.989e30 #kg
        sigma = 5.670374419e-8 #W/m^2K^4
        eta = 0.1
        kappa = 1
        mdotedd = 10**logmdotedd #LBol/LEdd

        if FluxWeight == True:
            X = 2.49
        else:
            X=4.97

        Ledd = lumin(M)
        alpha = (1/c)*(X*k*lam0/(h*c) )**(4/3) * ( (G*M/(8*np.pi*sigma)) * (Ledd/(eta*c**2)) * (3+kappa)*mdotedd )**(1/3)
        return alpha/(60*60*24)
    
    def model_X(theta, params, lam=fitwave, lam0=lam0):
        logmdotedd, syserr=theta
        M, beta, y0, FluxWeight=params
        a = alpha(theta, FluxWeight=FluxWeight)
        return a
    
    def sample_walkers(model, params, nsamples=100, flattened_chain=flat_samples):
        models=[]
        draw = np.floor(np.random.uniform(0, len(flat_samples), size=nsamples)).astype(int)
        thetas = flattened_chain[draw]
        for i in thetas:
            mod = model(i, params)
            models.append(mod)
        spread = np.std(models, axis=0)
        med_model = np.median(models, axis=0)
        return med_model, spread
    
    med, spread = sample_walkers(model_X, params)
    
    print('%7.3f +- %7.3f days'%(med, spread))
    
    return med, spread

In [None]:
#X=2.49
params = [10**8*1.989e30, 4/3, 1, True]
med2, spread2 = tauCompare(filters,delay_ref,outputdirs=outputdir,
           burnin=burnin,
           band_colors=band_colors,
           wavelengths=waves,redshift=redshift,
           dataframe = DF_FCCF, theta = models[1][0][0], flat_samples = models[1][-1], params=params)

#X=4.97
params = [10**8*1.989e30, 4/3, 1, False]
med4, spread4 = tauCompare(filters,delay_ref,outputdirs=outputdir,
           burnin=burnin,
           band_colors=band_colors,
           wavelengths=waves,redshift=redshift,
           dataframe = DF_FCCF, theta = models[2][0][0], flat_samples = models[2][-1], params=params)

In [None]:
#Observed tau0
real = dfs[0]['theta'][0][0]
minus = dfs[0]['stdminus'][0][0]
plus = dfs[0]['stdplus'][0][0]
print('%7.3f - %7.1f + %7.3f days'%(real, minus, plus))

In [None]:
#ratios
print(real/med2, real/med2 * np.sqrt((minus/real)**2 + (spread2/med2)**2), real/med4, real/med4 * np.sqrt((minus/real)**2 + (spread4/med4)**2))

## disc+BLR lag spectrum

In [None]:
# following functions are adapted from Utils.py accessible at https://github.com/Alymantara/PyROA
# Reference: Donnan et al. 2021 (https://ui.adsabs.harvard.edu/abs/2021arXiv210712318D/abstract)
def autocorr(samples_flat):

    chain = samples_flat[:,:,0].T
  # Compute the estimators for a few different chain lengths
    N = np.exp(np.linspace(np.log(init_chain_length), np.log(chain.shape[0]), 10)).astype(int)
    gw2010 = np.empty(len(N))
    new = np.empty(len(N))
    for i, n in enumerate(N):
            gw2010[i] = autocorr_gw2010(chain[:, :n])
            new[i] = autocorr_new(chain[:, :n])

    fig = plt.figure(figsize=(8,6))
    # Plot the comparisons
    plt.loglog(N, gw2010, "o-", label="G&W 2010")
    plt.loglog(N, new, "o-", label="new")
    ylim = plt.gca().get_ylim()
    plt.plot(N, N / 50., "--k", label=r"$\tau = N/50$")
    plt.ylim(ylim)
    plt.xlabel("number of samples, $N$")
    plt.ylabel(r"$\tau$ estimates")
    plt.legend(fontsize=14)


# Automated windowing procedure following Sokal (1989)
def auto_window(taus, c):
    m = np.arange(len(taus)) < c * taus
    if np.any(m):
            return np.argmin(m)
    return len(taus) - 1


# Following the suggestion from Goodman & Weare (2010)
def autocorr_gw2010(y, c=5.0):
    f = autocorr_func_1d(np.mean(y, axis=0))
    taus = 2.0 * np.cumsum(f) - 1.0
    window = auto_window(taus, c)
    return taus[window]


def autocorr_new(y, c=5.0):
    f = np.zeros(y.shape[1])
    for yy in y:
            f += autocorr_func_1d(yy)
    f /= len(y)
    taus = 2.0 * np.cumsum(f) - 1.0
    window = auto_window(taus, c)
    return taus[window]

def next_pow_two(n):
    i = 1
    while i < n:
            i = i << 1
    return i


def autocorr_func_1d(x, norm=True):
    x = np.atleast_1d(x)
    if len(x.shape) != 1:
            raise ValueError("invalid dimensions for 1D autocorrelation function")
    n = next_pow_two(len(x))

    # Compute the FFT and then (from that) the auto-correlation function
    f = np.fft.fft(x - np.mean(x), n=2 * n)
    acf = np.fft.ifft(f * np.conjugate(f))[: len(x)].real
    acf /= 4 * n

    # Optionally normalize
    if norm:
            acf /= acf[0]

    return acf

In [None]:
#Eddington luminosity function
def lumin(M = (10**8)*1.989e30):
    G = 6.6743e-11 #Nm^2/kg^2
    # M = (10**8.05)*1.989e30 #kg
    c = 299792458 #m/s
    mp = 1.67262192*1e-27
    sigmaTh = 6.6524587321e-29 #m^2
    Ledd = 4*np.pi*G*M*c*mp/sigmaTh
    return Ledd*1e7 #erg

In [None]:
FILTERS = ["UVW2", "gp1", "gp", "V", "rp", "ip", "zs"]
filters = FILTERS
waves = [1894, 4770, 5510, 6231, 7625, 9134]
redshift = 0.0364
band_colors=['#0652DD','#1289A7','#006266','#A3CB38','orange','#EE5A24','brown']
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2/"
# AS5101/samples_flat.obj
burnin=150000
def BLRLagTrue(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdir = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, resume=False):
    
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    #print(ss)
    labels = []
    for i in range(len(filters)):
        for j in ["A", "B",r"$\tau$", r"$\sigma$"]:
            labels.append(j+r'$_{'+filters[i]+r'}$')
    labels.append(r'$\Delta$')
    all_labels = labels.copy()
    del labels[ss*4+2]

    # To get ONLY lags
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
    # Get the 
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50
    fig, ax = plt.subplots(2, figsize=(17,12), height_ratios=[5,1],sharex=True)
    plt.subplots_adjust(wspace=0, hspace=0)

    ax[0].axhline(y=0,ls='--',alpha=0.5,color='k')

    if band_colors == None: band_colors = 'k'*7
    xerr = [584.89/2, 1262.68/2, 840/2, 1149.52/2, 1238.95/2, 994.39/2]
    mm = 0
    for i in range(lag.size):        
        ax[0].errorbar(wavelengths[i]/(1+redshift),lag[i]/(1+redshift),
                    yerr=lag_m[i], xerr=xerr[i],marker='o',
                    color='k', zorder=100)

    if redshift > 0:
        ax[1].set_xlabel(r'Rest Wavelength ($\mathrm{\AA})$')
        ax[0].set_ylabel(r'$\tau_{\rm rest}$ (day)')
    else:
        ax[1].set_xlabel(r'Observed Wavelength ($\mathrm{\AA})$')
        ax[0].set_ylabel(r'$\tau$ / day')
        
    delta = round(np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]]))
        
    ax[0].lines[-1].set_label(fr'lags, $\Delta$={round(delta,2)}')
    
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)
    
    ax[0].axvline(lam0, color='k', linestyle="--", alpha=0.5, zorder=0)

    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    BLRX = BLR_DF['X'].values/(1+redshift)
    BLRY = BLR_DF['Y'].values
    XRYX = np.linspace(1, max(BLRX),1000)
    
    def alpha(theta, FluxWeight=True, lam0=1):
        # print("alpha")
        
        A, x, beta, y0, logmdotedd = theta
        
        c = 299792458 #m/s
        k = 1.3806452e-23 #J/K
        lam0 = lam0*1e-10#4770*1e-10 #m
        h = 6.62607015e-34 #J/Hz
        G = 6.6743e-11 #Nm^2/kg^2
        M = (10**8.00)*1.989e30 #kg
        sigma = 5.670374419e-8 #W/m^2K^4
        eta = 0.1
        kappa = 1
        mdotedd = 10**logmdotedd
        # mdotedd = 9e-2 #LBol/LEdd; from https://academic.oup.com/mnras/article/419/3/2529/1069495
        Ledd = lumin(M)/1e7 #W
        
        if FluxWeight == True:
            X = 2.49
        else:
            X=4.97
            
        alpha = (1/c)*( X*k*lam0/(h*c) )**(4/3) * ( (G*M/(8*np.pi*sigma)) * (Ledd/(eta*c**2)) * (3+kappa)*mdotedd )**(1/3)
        return alpha/(60*60*24)
    
    
    def timelag(theta, lam=fitwave, lam0=1):
        # print("timelag")
        A, x, beta, y0, logmdotedd = theta
        
        alph = alpha(theta)
        return alph*((lam/lam0)**(beta) - y0), alph
    
    #for calculating the lag at 4770
    def model0(theta, lam=lam0, lam0=lam0, BLRY=BLRY):
        # print(0)
        A, x, beta, y0, logmdotedd = theta
       
        powerlaw = timelag(theta, lam=lam)[0] * ((1-x)/(1-A*x))
        # print("powerlaw0", powerlaw)
        # powerlaw = tau0*((lam/lam0)**(beta)-y0)
        
        mdotedd=10**logmdotedd
        L841 = lumin()*mdotedd
        L5548 = lumin(M = 14.22*1e30*1e7)*0.03
        # L841 = 5.1720e+43
        BLRY = BLRY*(L841/L5548)**0.5
        BLRY = BLRY*(((1-A)*x)/(1-A*x))
        tck = sp.interpolate.splrep(BLRX, BLRY, k=1)
        BLR = sp.interpolate.splev(lam, tck, ext=1)
        total = powerlaw+BLR
        # print("end0", total)
        return total
    
    #for comparison against wavelength
    def model(theta, lam=fitwave, lam0=lam0, idx=0):
        # print("mod")
        total = modelcal(theta)[idx]
        tck = sp.interpolate.splrep(XRYX, total, k=1)
        result = sp.interpolate.splev(lam, tck)
        # print("end mod")
        return result
    
    def model3(theta, lam=fitwave, lam0=lam0):
        A, X, beta, y0, logmdotedd, syserr = theta
        inputs = [A, X, beta, y0, logmdotedd]
        total = modelcal(inputs)[0]
        tck = sp.interpolate.splrep(XRYX, total, k=1)
        result = sp.interpolate.splev(lam, tck)
        return result
    
    #for cal
    def modelcal(theta, lam=fitwave, lam0=lam0, BLRY=BLRY):
        # print("cal")
        A, x, beta, y0, logmdotedd = theta

        reference = model0(theta)
        # print("ref", reference)
        powerlaw = timelag(theta)[0] * ((1-x)/(1-A*x))
        tck = sp.interpolate.splrep(lam, powerlaw)
        attempt = sp.interpolate.splev(XRYX, tck)
        # print("power", attempt)
        mdotedd=10**logmdotedd
        L841 = lumin()*mdotedd
        L5548 = lumin(M = 14.22*1e30*1e7)*0.03
        # print(L841/L5548)
        BLRY = BLRY*(L841/L5548)**0.5
        # print(BLRY)
        BLRY = BLRY*(((1-A)*x)/(1-A*x))
        # print(BLRY)
        tck = sp.interpolate.splrep(BLRX, BLRY,k=1)
        BLRY = sp.interpolate.splev(XRYX, tck, ext=1)
        # print("BLRY", BLRY)
        total = attempt+BLRY - reference
        # print("endcal", total)
        return total, attempt+BLRY
    
    def modelplt(theta, lam=fitwave, lam0=lam0, BLRY=BLRY):
        A, x, beta, y0, logmdotedd = theta

        reference = model0(theta)
        powerlaw, alph = timelag(theta)
        powerlaw = powerlaw * ((1-x)/(1-A*x))
        tck = sp.interpolate.splrep(lam, powerlaw)
        attempt = sp.interpolate.splev(XRYX, tck)
        
        mdotedd=10**logmdotedd
        L841 = lumin()*mdotedd
        L5548 = lumin(M = 14.22*1e30*1e7)*0.03
        BLRY = BLRY*(L841/L5548)**0.5
        BLRY = BLRY*(((1-A)*x)/(1-A*x))# - BLR_corr
        tck = sp.interpolate.splrep(BLRX, BLRY,k=1)
        BLRY = sp.interpolate.splev(XRYX, tck, ext=1)
                                       
        total = attempt+BLRY - reference
        return total, attempt-reference, BLRY-reference, reference, attempt, BLRY
    
    def residual(theta, x, y, yerr):
        # print("res")
        A = theta['A'].value
        X = theta['X'].value
        beta = theta['beta'].value
        y0 = theta['y0'].value
        logmdotedd = theta['logmdotedd'].value
        params = np.array([A, X, beta, y0, logmdotedd])
        res = (y - model(params, x))/yerr
        # print("endres", res)
        return res
    
    def residualplt(theta, x, y, yerr):
        res = (y - model(theta, x))/yerr
        reserr = 1
        return res, reserr
    
    def lnlike(theta, x, y, yerr):
        A, X, beta, y0, logmdotedd, syserr = theta
        return -1/2 * np.sum(((y - model3(theta, x))/yerr)**2 +np.log(yerr**2+syserr**2) )

    def lnprior(theta):
        A, X, beta, y0, logmdotedd, syserr = theta
        if 0 < A < 1 and 0 < X < 1 and 4/3-1e-1 < beta < 4/3+1e-1 and 0 < y0 < 1 and np.log10(2e-2) < logmdotedd < np.log10(18e-2) and 0 < syserr < 5.0:
            return 0.0
        return -np.inf

    def lnprob(theta, x, y, yerr):
        lp = lnprior(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike(theta, x, y, yerr)
 
    def fitter2(data, residual, fit_params, model, labels, nwalkers=500, niter=200000):
        
        result = lmft.minimize(residual, fit_params, args=data)
        # display(result.params)
        keys = [result.params[key].value for key in result.params]
        # print(keys)
        keys.append(1) #add syserr
        # print(keys)
        # params = [keys[0], 10**8*1.989e30, 4/3, 1]
        
        initial = np.array(keys)
        ndim = len(initial)
        p0 = [np.array(initial) + 1e-7 * np.random.randn(ndim) for i in range(nwalkers)]
        
        filename = f"blr_sample.h5"
        backend = emcee.backends.HDFBackend(filename)
        # backend.reset(nwalkers, ndim)
            
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, args=data, threads=12, backend=backend)
        if resume == True:
            p0 = None
            pos, prob, state = sampler.run_mcmc(p0, 1, progress=True)
        else:
            p0, _, _ = sampler.run_mcmc(p0, 100, progress=True)
            sampler.reset()
            pos, prob, state = sampler.run_mcmc(p0, niter, progress=True)
        
        samples = sampler.flatchain
        
        # chains(sampler, labels, ndim)
        
        print(sampler.get_autocorr_time())
        
        flat_samples = sampler.get_chain(discard=3*round(np.mean(sampler.get_autocorr_time())), thin=15, flat=True)
        
        # emcee_corner = corner.corner(flat_samples, labels=labels, show_titles=True)
        
        theta_max = samples[np.argmax(sampler.flatlnprobability)]
        
        hp = []
        stdminus = []
        stdplus = []
        for i in range(ndim):
            mcmc = np.percentile(flat_samples[:,i], [16,50,84])
            q = np.diff(mcmc)
            hp.append(mcmc[1])
            stdminus.append(q[0])
            stdplus.append(q[1])
        
        sig = np.sqrt(fiterr**2 + hp[-1])
    
        def sample_walkers(nsamples=10000, flattened_chain=flat_samples):
            models=[]
            powers=[]
            BLRS=[]
            refs=[]
            draw = np.floor(np.random.uniform(0, len(flat_samples), size=nsamples)).astype(int)
            thetas = flattened_chain[draw]
            for i in thetas:
                best_fit_model, power_model, BLR_model, reference, power, BLR = modelplt(i[:-1])
                models.append(best_fit_model)
                powers.append(power_model)
                BLRS.append(BLR_model)
                refs.append(reference)
            values = [models, powers, BLRS, refs]
            spread = [np.std(value, axis=0) for value in values]
            med_model = [np.median(value, axis=0) for value in values]
            return med_model, spread
        
        med_model, spread = sample_walkers()
        
        return theta_max, stdminus, stdplus, med_model, spread, sampler, sampler.get_chain(), flat_samples
    
    
    
    fit_params = lmft.create_params(A={'value':0.5, 'min':0, 'max':1,'vary':True},
                                    X={'value':0.6, 'min':0, 'max':1,'vary':True},
                                    beta={'value':4/3, 'min':4/3-1e-1, 'max':4/3+1e-1,'vary':True},
                                    y0={'value':0.94409927, 'min':0, 'max':1, 'vary':True},
                                    logmdotedd={'value':-1.17, 'min':np.log10(2e-2), 'max':np.log10(18e-2), 'vary':True})
    
    data = (fitwave, fitlag, fiterr)
    
    keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples = fitter2(data, residual, fit_params, model3, labels)
    bestmodel = [keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples]
    
    best_fit_model, power_model, BLR_model, reference, power, BLR = modelplt(keys[:-1])
    
    allcomp = [best_fit_model, power_model, BLR_model, reference, power, BLR]
    
    upper = []
    lower = []
    for i in range(len(med_model)):
        upper.append(med_model[i]+spread[i])
        lower.append(med_model[i]-spread[i])
    
    ax[0].plot(XRYX, best_fit_model, label="disc+BLR", color="orange", zorder=5)
    ax[0].plot(XRYX, power_model,linestyle="--", color="deeppink", alpha=0.5, zorder=0, label="disc component")
    ax[0].plot(XRYX, BLR_model,linestyle="--", color="dodgerblue", alpha=0.5, zorder=0, label="BLR component")
    # ax[0].axhline(reference, linestyle="dashdot", color="indigo", alpha=0.5, zorder=0, label=fr"$\lambda_0$ offset")
    
    ax[0].fill_between(XRYX, upper[0], lower[0], color="orange", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    ax[0].fill_between(XRYX, upper[1], lower[1], color="deeppink", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    ax[0].fill_between(XRYX, upper[2], lower[2], color="dodgerblue", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    # ax[0].fill_between(XRYX, upper[3], lower[3], color="indigo", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    
    ax[1].set_ylabel(fr'$\chi$')
    res, reserr = residualplt(keys[:-1], fitwave, fitlag, fiterr)
    ax[1].errorbar(fitwave, res,yerr=reserr, marker='o', color="k", ls='none')
    ax[1].axhline(0, color='k', linestyle="--", alpha=0.5, zorder=0)
    
    lines, labels = ax[0].get_legend_handles_labels()
    
    ax[0].legend([lines[0], lines[1], lines[2], lines[3], (lines[4], lines[5], lines[6])], 
                 labels, handler_map={tuple: HandlerTuple(ndivide=None)})
    ax[0].yaxis.grid(True, which='minor')
    plt.minorticks_on()
    
    df = pd.DataFrame({'keys':keys, 'stdplus':stdplus, 'stdminus':stdminus})
    df.to_pickle(f"BLR_model.pkl")
    
    df = pd.DataFrame({'med_model':med_model, 'spread':spread})
    df.to_pickle(f"BLR_uncertainty.pkl")
    
    values = ['A', 'X', 'beta', 'y0', 'logmdotedd']
    
    df = pd.DataFrame({'syserr_step':samples[:, :, -1][0], 'syserr': samples[:, :, -1][1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}_step'] = samples[:, :, i][0]
        df.loc[:, fr'{values[i]}'] = samples[:, :, i][1]
    df.to_pickle(f"BLR_samples.pkl")
    
    df = pd.DataFrame({'syserr':flat_samples[:,-1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}'] = flat_samples[:,i]
    df.to_pickle(f"BLR_flat_samples.pkl")
    
    df = pd.DataFrame({'best_fit':best_fit_model, 'power_model':power_model, 'BLR_model':BLR_model, 'power':power, 'BLR': BLR})
    df.to_pickle(f"BLR_components.pkl")
    
    df = pd.DataFrame({'res':res, 'reserr':reserr})
    df.to_pickle(f"BLR_residuals.pkl")
    
    plt.savefig("BLR_fig.png", dpi=300)
    
    return bestmodel, allcomp, [res, reserr]

outputdir = "/home/jovyan/PyROA2/"
blr, comp, residuals = BLRLagTrue(filters,delay_ref,outputdir=outputdir,
                                burnin=burnin,
                                band_colors=band_colors,
                                wavelengths=waves,redshift=redshift,
                                dataframe = DF_FCCF, resume = True)

In [None]:
FILTERS = ["UVW2", "gp1", "gp", "V", "rp", "ip", "zs"]
filters = FILTERS
waves = [1894, 4770, 5510, 6231, 7625, 9134]
redshift = 0.0364
band_colors=['#0652DD','#1289A7','#006266','#A3CB38','orange','#EE5A24','brown']
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2/"
# AS5101/samples_flat.obj
burnin=150000
def BLRLagFalse(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdir = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, resume=False):
    
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    #print(ss)
    labels = []
    for i in range(len(filters)):
        for j in ["A", "B",r"$\tau$", r"$\sigma$"]:
            labels.append(j+r'$_{'+filters[i]+r'}$')
    labels.append(r'$\Delta$')
    all_labels = labels.copy()
    del labels[ss*4+2]

    # To get ONLY lags
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
    # Get the 
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50
    fig, ax = plt.subplots(2, figsize=(17,12), height_ratios=[5,1],sharex=True)
    plt.subplots_adjust(wspace=0, hspace=0)

    ax[0].axhline(y=0,ls='--',alpha=0.5,color='k')

    if band_colors == None: band_colors = 'k'*7
    xerr = [584.89/2, 1262.68/2, 840/2, 1149.52/2, 1238.95/2, 994.39/2]
    mm = 0
    for i in range(lag.size):        
        ax[0].errorbar(wavelengths[i]/(1+redshift),lag[i]/(1+redshift),
                    yerr=lag_m[i], xerr=xerr[i],marker='o',
                    color='k', zorder=100)

    if redshift > 0:
        ax[1].set_xlabel(r'Rest Wavelength ($\mathrm{\AA})$')
        ax[0].set_ylabel(r'$\tau_{\rm rest}$ (day)')
    else:
        ax[1].set_xlabel(r'Observed Wavelength ($\mathrm{\AA})$')
        ax[0].set_ylabel(r'$\tau$ / day')
        
    delta = round(np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]]))
        
    ax[0].lines[-1].set_label(fr'lags, $\Delta$={round(delta,2)}')
    
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)
    
    ax[0].axvline(lam0, color='k', linestyle="--", alpha=0.5, zorder=0)

    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    BLRX = BLR_DF['X'].values/(1+redshift)
    BLRY = BLR_DF['Y'].values
    XRYX = np.linspace(1, max(BLRX),1000)
    
    def alpha(theta, FluxWeight=False, lam0=1):
        # print("alpha")
        
        A, x, beta, y0, logmdotedd = theta
        
        c = 299792458 #m/s
        k = 1.3806452e-23 #J/K
        lam0 = lam0*1e-10#4770*1e-10 #m
        h = 6.62607015e-34 #J/Hz
        G = 6.6743e-11 #Nm^2/kg^2
        M = (10**8.00)*1.989e30 #kg
        sigma = 5.670374419e-8 #W/m^2K^4
        eta = 0.1
        kappa = 1
        mdotedd = 10**logmdotedd
        # mdotedd = 9e-2 #LBol/LEdd; from https://academic.oup.com/mnras/article/419/3/2529/1069495
        Ledd = lumin(M)/1e7 #W
        
        if FluxWeight == True:
            X = 2.49
        else:
            X=4.97
            
        alpha = (1/c)*( X*k*lam0/(h*c) )**(4/3) * ( (G*M/(8*np.pi*sigma)) * (Ledd/(eta*c**2)) * (3+kappa)*mdotedd )**(1/3)
        return alpha/(60*60*24)
    
    
    def timelag(theta, lam=fitwave, lam0=1):
        # print("timelag")
        A, x, beta, y0, logmdotedd = theta
        
        alph = alpha(theta)
        return alph*((lam/lam0)**(beta) - y0), alph
    
    #for calculating the lag at 4770
    def model0(theta, lam=lam0, lam0=lam0, BLRY=BLRY):
        # print(0)
        A, x, beta, y0, logmdotedd = theta
       
        powerlaw = timelag(theta, lam=lam)[0] * ((1-x)/(1-A*x))
        # print("powerlaw0", powerlaw)
        # powerlaw = tau0*((lam/lam0)**(beta)-y0)
        
        mdotedd=10**logmdotedd
        L841 = lumin()*mdotedd
        L5548 = lumin(M = 14.22*1e30*1e7)*0.03
        # L841 = 5.1720e+43
        BLRY = BLRY*(L841/L5548)**0.5
        BLRY = BLRY*(((1-A)*x)/(1-A*x))
        tck = sp.interpolate.splrep(BLRX, BLRY, k=1)
        BLR = sp.interpolate.splev(lam, tck, ext=1)
        total = powerlaw+BLR
        # print("end0", total)
        return total
    
    #for comparison against wavelength
    def model(theta, lam=fitwave, lam0=lam0, idx=0):
        # print("mod")
        total = modelcal(theta)[idx]
        tck = sp.interpolate.splrep(XRYX, total, k=1)
        result = sp.interpolate.splev(lam, tck)
        # print("end mod")
        return result
    
    def model3(theta, lam=fitwave, lam0=lam0):
        A, X, beta, y0, logmdotedd, syserr = theta
        inputs = [A, X, beta, y0, logmdotedd]
        total = modelcal(inputs)[0]
        tck = sp.interpolate.splrep(XRYX, total, k=1)
        result = sp.interpolate.splev(lam, tck)
        return result
    
    #for cal
    def modelcal(theta, lam=fitwave, lam0=lam0, BLRY=BLRY):
        # print("cal")
        A, x, beta, y0, logmdotedd = theta

        reference = model0(theta)
        # print("ref", reference)
        powerlaw = timelag(theta)[0] * ((1-x)/(1-A*x))
        tck = sp.interpolate.splrep(lam, powerlaw)
        attempt = sp.interpolate.splev(XRYX, tck)
        # print("power", attempt)
        mdotedd=10**logmdotedd
        L841 = lumin()*mdotedd
        L5548 = lumin(M = 14.22*1e30*1e7)*0.03
        # print(L841/L5548)
        BLRY = BLRY*(L841/L5548)**0.5
        # print(BLRY)
        BLRY = BLRY*(((1-A)*x)/(1-A*x))
        # print(BLRY)
        tck = sp.interpolate.splrep(BLRX, BLRY,k=1)
        BLRY = sp.interpolate.splev(XRYX, tck, ext=1)
        # print("BLRY", BLRY)
        total = attempt+BLRY - reference
        # print("endcal", total)
        return total, attempt+BLRY
    
    def modelplt(theta, lam=fitwave, lam0=lam0, BLRY=BLRY):
        A, x, beta, y0, logmdotedd = theta

        reference = model0(theta)
        powerlaw, alph = timelag(theta)
        powerlaw = powerlaw * ((1-x)/(1-A*x))
        tck = sp.interpolate.splrep(lam, powerlaw)
        attempt = sp.interpolate.splev(XRYX, tck)
        
        mdotedd=10**logmdotedd
        L841 = lumin()*mdotedd
        L5548 = lumin(M = 14.22*1e30*1e7)*0.03
        BLRY = BLRY*(L841/L5548)**0.5
        BLRY = BLRY*(((1-A)*x)/(1-A*x))# - BLR_corr
        tck = sp.interpolate.splrep(BLRX, BLRY,k=1)
        BLRY = sp.interpolate.splev(XRYX, tck, ext=1)
                                       
        total = attempt+BLRY - reference
        return total, attempt-reference, BLRY-reference, reference, attempt, BLRY
    
    def residual(theta, x, y, yerr):
        # print("res")
        A = theta['A'].value
        X = theta['X'].value
        beta = theta['beta'].value
        y0 = theta['y0'].value
        logmdotedd = theta['logmdotedd'].value
        params = np.array([A, X, beta, y0, logmdotedd])
        res = (y - model(params, x))/yerr
        # print("endres", res)
        return res
    
    def residualplt(theta, x, y, yerr):
        res = (y - model(theta, x))/yerr
        reserr = 1
        return res, reserr
    
    def lnlike(theta, x, y, yerr):
        A, X, beta, y0, logmdotedd, syserr = theta
        return -1/2 * np.sum(((y - model3(theta, x))/yerr)**2 +np.log(yerr**2+syserr**2) )

    def lnprior(theta):
        A, X, beta, y0, logmdotedd, syserr = theta
        if 0 < A < 1 and 0 < X < 1 and 4/3-1e-1 < beta < 4/3+1e-1 and 0 < y0 < 1 and np.log10(2e-2) < logmdotedd < np.log10(18e-2) and 0 < syserr < 5.0:
            return 0.0
        return -np.inf

    def lnprob(theta, x, y, yerr):
        lp = lnprior(theta)
        if not np.isfinite(lp):
            return -np.inf
        return lp + lnlike(theta, x, y, yerr)
 
    def fitter2(data, residual, fit_params, model, labels, nwalkers=500, niter=200000):
        
        result = lmft.minimize(residual, fit_params, args=data)
        # display(result.params)
        keys = [result.params[key].value for key in result.params]
        # print(keys)
        keys.append(1) #add syserr
        # print(keys)
        # params = [keys[0], 10**8*1.989e30, 4/3, 1]
        
        initial = np.array(keys)
        ndim = len(initial)
        p0 = [np.array(initial) + 1e-7 * np.random.randn(ndim) for i in range(nwalkers)]
        
        filename = f"blr_sample.h5"
        backend = emcee.backends.HDFBackend(filename)
        # backend.reset(nwalkers, ndim)
            
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, args=data, threads=12, backend=backend)
        if resume == True:
            p0 = None
            pos, prob, state = sampler.run_mcmc(p0, 1, progress=True)
        else:
            p0, _, _ = sampler.run_mcmc(p0, 100, progress=True)
            sampler.reset()
            pos, prob, state = sampler.run_mcmc(p0, niter, progress=True)
        
        samples = sampler.flatchain
        
        # chains(sampler, labels, ndim)
        
        print(sampler.get_autocorr_time())
        
        flat_samples = sampler.get_chain(discard=3*round(np.mean(sampler.get_autocorr_time())), thin=15, flat=True)
        
        # emcee_corner = corner.corner(flat_samples, labels=labels, show_titles=True)
        
        theta_max = samples[np.argmax(sampler.flatlnprobability)]
        
        hp = []
        stdminus = []
        stdplus = []
        for i in range(ndim):
            mcmc = np.percentile(flat_samples[:,i], [16,50,84])
            q = np.diff(mcmc)
            hp.append(mcmc[1])
            stdminus.append(q[0])
            stdplus.append(q[1])
        
        sig = np.sqrt(fiterr**2 + hp[-1])
    
        def sample_walkers(nsamples=10000, flattened_chain=flat_samples):
            models=[]
            powers=[]
            BLRS=[]
            refs=[]
            draw = np.floor(np.random.uniform(0, len(flat_samples), size=nsamples)).astype(int)
            thetas = flattened_chain[draw]
            for i in thetas:
                best_fit_model, power_model, BLR_model, reference, power, BLR = modelplt(i[:-1])
                models.append(best_fit_model)
                powers.append(power_model)
                BLRS.append(BLR_model)
                refs.append(reference)
            values = [models, powers, BLRS, refs]
            spread = [np.std(value, axis=0) for value in values]
            med_model = [np.median(value, axis=0) for value in values]
            return med_model, spread
        
        med_model, spread = sample_walkers()
        
        return theta_max, stdminus, stdplus, med_model, spread, sampler, sampler.get_chain(), flat_samples
    
    
    
    fit_params = lmft.create_params(A={'value':0.5, 'min':0, 'max':1,'vary':True},
                                    X={'value':0.6, 'min':0, 'max':1,'vary':True},
                                    beta={'value':4/3, 'min':4/3-1e-1, 'max':4/3+1e-1,'vary':True},
                                    y0={'value':0.94409927, 'min':0, 'max':1, 'vary':True},
                                    logmdotedd={'value':-1.17, 'min':np.log10(2e-2), 'max':np.log10(18e-2), 'vary':True})
    
    data = (fitwave, fitlag, fiterr)
    
    keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples = fitter2(data, residual, fit_params, model3, labels)
    bestmodel = [keys, stdminus, stdplus, med_model, spread, sampler, samples, flat_samples]
    
    best_fit_model, power_model, BLR_model, reference, power, BLR = modelplt(keys[:-1])
    
    allcomp = [best_fit_model, power_model, BLR_model, reference, power, BLR]
    
    upper = []
    lower = []
    for i in range(len(med_model)):
        upper.append(med_model[i]+spread[i])
        lower.append(med_model[i]-spread[i])
    
    ax[0].plot(XRYX, best_fit_model, label="disc+BLR", color="orange", zorder=5)
    ax[0].plot(XRYX, power_model,linestyle="--", color="deeppink", alpha=0.5, zorder=0, label="disc component")
    ax[0].plot(XRYX, BLR_model,linestyle="--", color="dodgerblue", alpha=0.5, zorder=0, label="BLR component")
    # ax[0].axhline(reference, linestyle="dashdot", color="indigo", alpha=0.5, zorder=0, label=fr"$\lambda_0$ offset")
    
    ax[0].fill_between(XRYX, upper[0], lower[0], color="orange", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    ax[0].fill_between(XRYX, upper[1], lower[1], color="deeppink", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    ax[0].fill_between(XRYX, upper[2], lower[2], color="dodgerblue", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    # ax[0].fill_between(XRYX, upper[3], lower[3], color="indigo", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    
    ax[1].set_ylabel(fr'$\chi$')
    res, reserr = residualplt(keys[:-1], fitwave, fitlag, fiterr)
    ax[1].errorbar(fitwave, res,yerr=reserr, marker='o', color="k", ls='none')
    ax[1].axhline(0, color='k', linestyle="--", alpha=0.5, zorder=0)
    
    lines, labels = ax[0].get_legend_handles_labels()
    
    ax[0].legend([lines[0], lines[1], lines[2], lines[3], (lines[4], lines[5], lines[6])], 
                 labels, handler_map={tuple: HandlerTuple(ndivide=None)})
    ax[0].yaxis.grid(True, which='minor')
    plt.minorticks_on()
    
    df = pd.DataFrame({'keys':keys, 'stdplus':stdplus, 'stdminus':stdminus})
    df.to_pickle(f"BLR_model.pkl")
    
    df = pd.DataFrame({'med_model':med_model, 'spread':spread})
    df.to_pickle(f"BLR_uncertainty.pkl")
    
    values = ['A', 'X', 'beta', 'y0', 'logmdotedd']
    
    df = pd.DataFrame({'syserr_step':samples[:, :, -1][0], 'syserr': samples[:, :, -1][1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}_step'] = samples[:, :, i][0]
        df.loc[:, fr'{values[i]}'] = samples[:, :, i][1]
    df.to_pickle(f"BLR_samples.pkl")
    
    df = pd.DataFrame({'syserr':flat_samples[:,-1]})
    for i in range(len(values)):
        df.loc[:, fr'{values[i]}'] = flat_samples[:,i]
    df.to_pickle(f"BLR_flat_samples.pkl")
    
    df = pd.DataFrame({'best_fit':best_fit_model, 'power_model':power_model, 'BLR_model':BLR_model, 'power':power, 'BLR': BLR})
    df.to_pickle(f"BLR_components.pkl")
    
    df = pd.DataFrame({'res':res, 'reserr':reserr})
    df.to_pickle(f"BLR_residuals.pkl")
    
    plt.savefig("BLR_fig.png", dpi=300)
    
    return bestmodel, allcomp, [res, reserr]

outputdir = "/home/jovyan/PyROA2/"
blr, comp, residuals = BLRLagFalse(filters,delay_ref,outputdir=outputdir,
                                burnin=burnin,
                                band_colors=band_colors,
                                wavelengths=waves,redshift=redshift,
                                dataframe = DF_FCCF, resume = True)

In [None]:
FILTERS = ["UVW2", "gp1", "gp", "V", "rp", "ip", "zs"]
filters = FILTERS
waves = [1894, 4770, 5510, 6231, 7625, 9134]
redshift = 0.0364
band_colors=['#0652DD','#1289A7','#006266','#A3CB38','orange','#EE5A24','brown']
delay_ref = 'gp1'
outputdir = "/home/jovyan/PyROA2/"
# AS5101/samples_flat.obj
burnin=150000
def BLRLagPlot(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdir = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, name=None, add=""):
    
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    #print(ss)
    labels = []
    for i in range(len(filters)):
        for j in ["A", "B",r"$\tau$", r"$\sigma$"]:
            labels.append(j+r'$_{'+filters[i]+r'}$')
    labels.append(r'$\Delta$')
    all_labels = labels.copy()
    del labels[ss*4+2]

    # To get ONLY lags
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
    # Get the 
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50
    fig, ax = plt.subplots(2, figsize=(15,10), height_ratios=[5,1],sharex=True)
    plt.subplots_adjust(wspace=0, hspace=0)

    ax[0].axhline(y=0,ls='--',alpha=0.5,color='k')

    if band_colors == None: band_colors = 'k'*7
    xerr = [584.89/2, 1262.68/2, 840/2, 1149.52/2, 1238.95/2, 994.39/2]
    mm = 0
    for i in range(lag.size):        
        ax[0].errorbar(wavelengths[i]/(1+redshift),lag[i]/(1+redshift),
                    yerr=lag_m[i], xerr=xerr[i],marker='o',
                    color='k', zorder=100)

    if redshift > 0:
        ax[1].set_xlabel(r'Rest Wavelength ($\mathrm{\AA})$')
        ax[0].set_ylabel(r'$\tau_{\rm rest}$ (day)')
    else:
        ax[1].set_xlabel(r'Observed Wavelength ($\mathrm{\AA})$')
        ax[0].set_ylabel(r'$\tau$ / day')
        
    delta = round(np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]]))
        
    ax[0].lines[-1].set_label(fr'lags, $\Delta$={round(delta,2)}')
    
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)
    
    ax[0].axvline(lam0, color='k', linestyle="--", alpha=0.5, zorder=0)

    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    BLRX = BLR_DF['X'].values/(1+redshift)
    BLRY = BLR_DF['Y'].values
    XRYX = np.linspace(1, max(BLRX),1000)
    
    df = pd.read_pickle(f"{name}_model{add}.pkl")
    keys = df['keys'].values
    stdplus = df['stdplus'].values
    stdminus=df['stdminus'].values
    
    df = pd.read_pickle(f"{name}_uncertainty{add}.pkl")
    med_model = df['med_model'].values
    spread = df['spread'].values
    
    values = ['A', 'X', 'beta', 'y0', 'logmdotedd', 'syserr']
    df = pd.read_pickle(f"{name}_flat_samples{add}.pkl")
    flat_samples = np.array([df[f'{v}'] for v in values]).T
    
    df = pd.read_pickle(f"{name}_samples{add}.pkl")
    samples = []
    for i in range(len(values)-1):
        v=values[i]
        smpl = np.array(df[f'{v}_step'], df[f'{v}'])
        samples.append(smpl)
    samples.append(np.array(df['step'], df['syserr']))
    
    df = pd.read_pickle(f"{name}_components{add}.pkl")
    best_fit_model = df['best_fit'].values
    power_model = df['power_model'].values
    BLR_model = df['BLR_model'].values
    power = df['power'].values
    BLR = df['BLR'].values
    
    df = pd.read_pickle(f"{name}_residuals{add}.pkl")
    res = df['res'].values
    reserr = df['reserr'].values   
    
    upper = []
    lower = []
    for i in range(len(med_model)):
        upper.append(med_model[i]+spread[i])
        lower.append(med_model[i]-spread[i])
    
    ax[0].plot(XRYX, best_fit_model, label="disc+BLR", color="orange", zorder=5)
    ax[0].plot(XRYX, power_model,linestyle="--", color="deeppink", alpha=0.5, zorder=0, label="disc component")
    ax[0].plot(XRYX, BLR_model,linestyle="--", color="dodgerblue", alpha=0.5, zorder=0, label="BLR component")
    # ax[0].axhline(reference, linestyle="dashdot", color="indigo", alpha=0.5, zorder=0, label=fr"$\lambda_0$ offset")
    
    ax[0].fill_between(XRYX, upper[0], lower[0], color="orange", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    ax[0].fill_between(XRYX, upper[1], lower[1], color="deeppink", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    ax[0].fill_between(XRYX, upper[2], lower[2], color="dodgerblue", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    # ax[0].fill_between(XRYX, upper[3], lower[3], color="indigo", alpha=0.2, zorder=0, label=fr"$1\sigma$ spread")
    
    ax[1].set_ylabel(fr'$\chi$')
    ax[1].errorbar(fitwave, res,yerr=reserr, marker='o', color="k", ls='none')
    ax[1].axhline(0, color='k', linestyle="--", alpha=0.5, zorder=0)
    
    lines, labels = ax[0].get_legend_handles_labels()
    
    ax[0].legend([lines[0], lines[1], lines[2], lines[3], (lines[4], lines[5], lines[6])], 
                 labels, handler_map={tuple: HandlerTuple(ndivide=None)}, ncol=5, bbox_to_anchor=(0,0))
    ax[0].yaxis.grid(True, which='minor')
    plt.minorticks_on()
    
    bestmodel = [keys, stdminus, stdplus, med_model, spread, samples, flat_samples]
    allcomp = [best_fit_model, power_model, BLR_model, power, BLR]
    
    plt.savefig("BLR_fig.png", dpi=300)
    
    labels=['A', 'x', r'$\beta$', r'$y_0$', r'$\log{\dot{M}_{\rm Edd}}$', 'syserr']
    
    return bestmodel, allcomp, [res, reserr]

outputdir = "/home/jovyan/PyROA2/"
blr, comp, residuals = BLRLagPlot(filters,delay_ref,outputdir=outputdir,
                                burnin=burnin,
                                band_colors=band_colors,
                                wavelengths=waves,redshift=redshift,
                                dataframe = DF_FCCF, name="BLR")

In [None]:
def mathdisplay(keys, stdminus, stdplus, labels):
    for i in range(len(keys)):
        txt = "\mathrm{{{3}}} = {0:.3f}_{{-{1:.3f}}}^{{{2:.3f}}}"
        txt = txt.format(keys[i], stdminus[i], stdplus[i], labels[i])
        display(Math(txt))
        
    print(keys[4]/(10**8 * 1.989e30))
    print(stdminus[4]/(10**8 * 1.989e30))
    print(stdplus[4]/(10**8 * 1.989e30))

In [None]:
def cornerplot(name, labels, add="", burnin=6540000):
    values = ['A', 'X', 'beta', 'y0', 'logmdotedd']
    df = pd.read_pickle(f"{name}_flat_samples{add}.pkl")
    flat_samples = np.array([df[f'{v}'] for v in values]).T[burnin:]
    emcee_corner = corner.corner(flat_samples, labels=labels, show_titles=True, quantiles=(0.16, 0.5, 0.84), title_kwargs={'fontsize':15}, levels=(0.5, 1, 2, 3))
    plt.savefig(f"{name}_corner{add}.png", dpi=300)

labels=['A', 'X', r'$\beta$', r'$y_0$', r'$\log{\dot{M}_{\rm Edd}}$']
cornerplot("BLR", labels)

In [None]:
def BLRtau0(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdir = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, flats=None, FluxWeight=True, labels=None):
    
    if outputdir[-1] != '/': outputdir += '/'
    file = open(outputdir+samples_file,'rb')
    samples = pickle.load(file)[burnin:]

    ss = np.where(np.array(filters) == delay_ref)[0][0]
    
    # To get ONLY lags
    shifter = 2

    list_only = []
    mm = 0
    ndim = len(filters)
    for i in range(ndim):
        if i != ss:
            list_only.append(i*4+shifter+mm)
        if i == ss:
            mm = -1
    # Get the 
    lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
    for j,i in enumerate(list_only):
        #print(i)
        q50 = np.percentile(samples[:,i],50)
        q84 = np.percentile(samples[:,i],84)
        q16 = np.percentile(samples[:,i],16)
        lag[j] = q50
        lag_m[j] = q50-q16
        lag_p[j] = q84-q50
        
    xerr = [584.89/2, 1262.68/2, 840/2, 1149.52/2, 1238.95/2, 994.39/2]
    
    delta = np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]])
    
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]/(1+redshift)

    fitwave = np.array(wavelengths)/(1+redshift)
    fitlag = np.array(lag)/(1+redshift)
    fiterr = np.array(lag_m)
    BLRX = BLR_DF['X'].values/(1+redshift)
    BLRY = BLR_DF['Y'].values
    XRYX = np.linspace(1, max(BLRX),1000)

    def alpha(theta, FluxWeight=True, lam0=1):
        A, x, beta, y0, logmdotedd = theta
        
        c = 299792458 #m/s
        k = 1.3806452e-23 #J/K
        lam0 = lam0*1e-10#4770*1e-10 #m
        h = 6.62607015e-34 #J/Hz
        G = 6.6743e-11 #Nm^2/kg^2
        M = (10**8.00)*1.989e30 #kg
        sigma = 5.670374419e-8 #W/m^2K^4
        eta = 0.1
        kappa = 1
        mdotedd = 10**logmdotedd
        Ledd = lumin(M)/1e7 #W
        
        if FluxWeight == True:
            X = 2.49
        else:
            X=4.97
            
        alpha = (1/c)*( X*k*lam0/(h*c) )**(4/3) * ( (G*M/(8*np.pi*sigma)) * (Ledd/(eta*c**2)) * (3+kappa)*mdotedd )**(1/3)
        return alpha/(60*60*24)
    
    
    def timelag(theta, lam=fitwave, lam0=1, FluxWeight=True):
        A, x, beta, y0, logmdotedd = theta
        
        alph = alpha(theta, FluxWeight=FluxWeight)
        return alph*((lam/lam0)**(4/3) - 1), alph
    
    def newmodel(theta, FluxWeight=True, lam=fitwave, BLRY=BLRY, callam1= 0, callam2=lam0):
        A, x, beta, y0, logmdotedd = theta
        
        pwl, alph = timelag(theta, FluxWeight=FluxWeight)
        
        tck = sp.interpolate.splrep(lam, pwl)
        pwl1 = sp.interpolate.splev(callam1, tck)
        pwl2 = sp.interpolate.splev(callam2, tck)
        # print(pwl1, pwl2)
        return pwl2-pwl1
    
    def sample_walkers(model, nsamples=10000, flattened_chain=None, FluxWeight=True):
        models=[]
        draw = np.floor(np.random.uniform(0, len(flat_samples), size=nsamples)).astype(int)
        thetas = flattened_chain[draw]
        for i in thetas:
            mod = model(i)
            models.append(mod)
        spread = np.std(models, axis=0)
        med_model = np.median(models, axis=0)
        return med_model, spread
    
    flat_samples = np.array([flats[f'{l}'] for l in labels]).T
    
    med_model, spread = sample_walkers(newmodel, flattened_chain=flat_samples, FluxWeight=FluxWeight)
    
    print('%7.3f +- %7.3f days'%(med, spread))
    
    return med_model, spread

In [None]:
flats = pd.read_pickle("BLR_flat_samples.pkl")
labels = ['A', 'X', 'beta', 'y0', 'logmdotedd']

In [None]:
med, spread = BLRtau0(filters,delay_ref,outputdir=outputdir,
                                burnin=burnin,
                                band_colors=band_colors,
                                wavelengths=waves,redshift=redshift,
                                dataframe = DF_FCCF, 
                      flats=flats, labels=labels, FluxWeight=False)

## Frequency-resolved lag spectrum

In [None]:
outputdir = ["/home/jovyan/PyROAF/", "/home/jovyan/PyROA2/","/home/jovyan/PyROA5/", "/home/jovyan/PyROA10/",  "/home/jovyan/PyROAD20/", "/home/jovyan/PyROA40"]
def FreqSpectrum(filters,delay_ref,wavelengths,
                burnin=0,samples_file='samples_flat.obj',
                outputdirs = './',
                band_colors = None,
                redshift=0.0,
                savefig=True,figname=None,
                dataframe=None, database=None):
    fig, ax = plt.subplots(figsize=(10,7))
    for s in range(len(outputdirs)):
        outputdir=outputdirs[s]
        if outputdir[-1] != '/': outputdir += '/'
        file = open(outputdir+samples_file,'rb')
        samples = pickle.load(file)[burnin:]
        deltas = np.mean([np.median(samples[:, -1]), np.mean(samples[:, -1]), sp.stats.mode(samples[:, -1])[0]])
        delta = np.array([fr'{round(deltas,2)}', 2, 5, 10, 20, 40])[s]
        color=np.array(['k', 'purple','r','b', 'orange', 'green'])[s]
        
        ss = np.where(np.array(filters) == delay_ref)[0][0]
        #print(ss)
        labels = []
        for i in range(len(filters)):
            for j in ["A", "B",r"$\tau$", r"$\sigma$"]:
                labels.append(j+r'$_{'+filters[i]+r'}$')
        labels.append(r'$\Delta$')
        all_labels = labels.copy()
        del labels[ss*4+2]

        # To get ONLY lags
        shifter = 2

        list_only = []
        mm = 0
        ndim = len(filters)
        for i in range(ndim):
            if i != ss:
                list_only.append(i*4+shifter+mm)
            if i == ss:
                mm = -1
        # Get the 
        lag,lag_m,lag_p = np.zeros(ndim-1),np.zeros(ndim-1),np.zeros(ndim-1)
        for j,i in enumerate(list_only):
            #print(i)
            q50 = np.percentile(samples[:,i],50)
            q84 = np.percentile(samples[:,i],84)
            q16 = np.percentile(samples[:,i],16)
            lag[j] = q50
            lag_m[j] = q50-q16
            lag_p[j] = q84-q50
        
        # ax = fig.add_subplot(111)

        plt.axhline(y=0,ls='--',alpha=0.5,color='k')

        if band_colors == None: band_colors = 'k'*7

        mm = 0
        wavelengths = np.array(wavelengths)
        # for i in range(lag.size):        
        plt.plot(wavelengths/(1+redshift),lag/(1+redshift),color=color)
        plt.fill_between(wavelengths/(1+redshift), lag/(1+redshift)-lag_m, lag/(1+redshift)+lag_m,color=color, alpha=0.3) 
            # print(wavelengths[i])

        if redshift > 0:
            plt.xlabel(r'Rest Wavelength ($\mathrm{\AA}$)')
            plt.ylabel(r'$\tau_{\rm rest}$ (day)')
        else:
            plt.xlabel(r'Observed Wavelength / $\mathrm{\AA}$')
            plt.ylabel(r'$\tau$ / day')
        if outputdir == "/home/jovyan/PyROAF/":
            ax.lines[-1].set_label(fr'Free parameter $\Delta \simeq$ {delta}')
        else:
            ax.lines[-1].set_label(fr'$\Delta$ = {delta}')
        # else:
            # ax.lines[-1].set_label(fr'PyROA')
        
    lam0 = dataframe.loc[(dataframe['Filter'] == 'gp')]['Wavelength'].values[0]
    
    ax.axvline(lam0, color='k', linestyle="--", alpha=0.5, zorder=0)
        
    plt.legend(fontsize=12, ncol=2)
    plt.savefig("freqspec.png", dpi=300)

In [None]:
FreqSpectrum(filters,delay_ref,outputdirs=outputdir,
                burnin=burnin,
                band_colors=band_colors,
                wavelengths=waves,redshift=redshift,
                dataframe = DF_FCCF)

# XSPEC plots

## Importing files

In [None]:
def df_no(df, names):
    for name in names:
        df.loc[df[fr'{name}'] == "NO", fr'{name}'] = float("nan")
    df = df.astype(np.float32)
    return df

In [None]:
names = ['X','X_err','Y','Y_err','total', 'relagn', 'relxill']
names_delchi = ['X','X_err','Y','Y_err']
modelname = ["relagn", "relxill"]
df = pd.read_table('2agnxill_data.qdp',skiprows=3,names=names, delimiter=' ')
df_delchi = pd.read_table('2agnxill02_delchi.qdp', names =names_delchi, skiprows=3, delimiter=" ")

df = df_no(df, names)
df_delchi=df_no(df_delchi, names_delchi)

names2 = ['X','X_err','Y','relagn', 'relxill']
modelname2 = ["relagn", "relxill"]
df2 = pd.read_table('2agnxill02_nufnu.qdp',skiprows=3,names=names2, delimiter=' ')
df2 = df_no(df2, names2)

## Plotting photon counts against energy and XSPEC model

In [None]:
def df_plot(df, modelname, delchi, X = 'Energy (keV)', Y = 'Photon (counts/s/keV)'):
    fig, ax = plt.subplots(2, figsize=(10,8), height_ratios=[5,1],sharex=True)
    plt.subplots_adjust(wspace=0, hspace=0)
    # # Plot using Matplotlib:
    ax[0].errorbar(df['X'].values, df['Y'].values, xerr=abs(df['X_err'].values),yerr=abs(df['Y_err'].values),fmt='o',label='data', zorder=0, color='k')
    ax[0].plot(df['X'].values, df['total'].values, color='red',label=r'2016-2017 XSPEC model',ls="--")
    color=['red', 'dodgerblue']
    linestyle=["dashed", "dotted"]
    for j in range(0):
        if j == 0:
            i = 0
        else:
            i = 4
        ax[0].plot(df['X'].values[i:], df[f'{modelname[j]}'].values[i:],label=f'{modelname[j]}', color=color[j], linestyle=linestyle[j])
    plt.xlabel(fr'{X}')
    ax[0].set_ylabel(fr'{Y}')
    ax[0].set_xscale("log")
    ax[0].set_yscale("log")
    ax[0].yaxis.set_ticks_position('left')
    ax[0].set_ylim((2e-5,1e4))
    ax[1].errorbar(delchi['X'].values, delchi['Y'].values, xerr=delchi['X_err'].values, yerr=delchi['Y_err'].values, fmt='.', color='k', alpha=0.7)
    ax[1].set_ylabel(fr'$\chi$')
    ax[1].axhline(0, ls="--", zorder=0, color="grey")
    ax[1].yaxis.set_ticks_position('left') 
    # ax[1].set_ylim(
    # ax.grid()
    ax[0].legend()
    ax[0].xaxis.set_ticks_position('both') 
    ax1 = ax[0].twiny()
    ax1.set_xticks(ax[0].get_xticks())
    ax1.set_xbound(ax[0].get_xbound())
    ax1.set_xscale('log')
    ax1.set_xticklabels([round(12.398/x) for x in ax[0].get_xticks()])
    
    ax[0].minorticks_off()
    ax[1].minorticks_off()
    ax1.minorticks_off()
    
    
    ax1.set_xlabel(r"Wavelength $\lambda$ ($\mathrm{\AA}$)")
    plt.show()
    fig.savefig("specdatares.png", dpi=300)
df_plot(df, modelname, df_delchi)

## Plotting XSPEC SED model

In [None]:
def A_energy(A):
        return 12.398425/A
    
#for converting nufnu xspec plots and plotting model components
def plotter(df2, modelname):
    energies = A_energy(df2['X'].values)
    fig, ax = plt.subplots(figsize=(10,6))
    # # Plot using Matplotlib:
    ax.plot(energies, df2['Y'].values*1e23, color='plum',label='total model') #Jy = 10−23 erg⋅s−1⋅cm−2⋅Hz−1
    color=['red', 'dodgerblue']
    for j in range(2):
        ax.plot(energies, df2[f'{modelname[j]}'].values*1e23,label=f'{modelname[j]}', color=color[j], linestyle="dotted")

    ax.set_xlabel('Energy (keV)')
    ax.set_ylabel(r'$\nu \mathrm{F}_\nu$')
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_ylim((1e10,1e13))
    ax.set_xlim((1e-3,1e3))
    ax.grid()
    ax.legend()
    plt.show()
    return energies, df2['Y'].values*1e23
energies, Y = plotter(df2, modelname2)

## Plotting XSPEC SED model against models from Mehdipour et al. 2023

In [None]:
def model_plot(dfs, models, colors, X = 'Energy (keV)', Y = fr'$\nu F_\nu$ (Jy)'):
    fig, ax = plt.subplots(figsize=(10,8))
    # # Plot using Matplotlib:
    for i in range(len(dfs)):
        df = dfs[i]
        ax.plot(df['X'].values, df['Y'].values,label=models[i], color=colors[i])
    
    ax.set_xlabel(fr'{X}')
    ax.set_ylabel(fr'{Y}')
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.yaxis.set_ticks_position('left')
    ax.legend()
    ax.xaxis.set_ticks_position('both') 
    ax.set_ylim((1e11,1e13))
    ax.set_xlim((1e-3,1e2))
    ax1 = ax.twiny()
    ax1.set_xticks(ax.get_xticks())
    ax1.set_xbound(ax.get_xbound())
    ax1.set_xscale('log')
    ax1.set_xticklabels([round(12.398/x) for x in ax.get_xticks()])
    
    ax.minorticks_off()
    ax1.minorticks_off()
    ax1.set_xlabel(r"Wavelength $\lambda$ ($\mathrm{\AA}$)")
    plt.show()
    fig.savefig("sedcomp.png", dpi=300)
    
red = pd.read_csv("red.csv", names=['X', 'Y'])
blue = pd.read_csv("blue.csv", names=['X', 'Y'])
dfmy = pd.DataFrame({'X':energies, 'Y':Y})
model_plot([blue, red, dfmy], ['2001', '2022', '2016 March-2017 August'], ['blue', 'red', 'green'])

# RELAGN components

In [None]:
dagn=relagn(M=1e8, dist=161, log_mdot=-1.17)

#getting total AGN SED
Lnu_rel = dagn.get_totSED(rel=True)

#getting components
Lrel_dsc = dagn.get_DiscComponent()
Lrel_wrm = dagn.get_WarmComponent()
Lrel_hot = dagn.get_HotComponent()

In [None]:
fig = plt.figure(figsize=(11, 8))
ax = fig.add_subplot(111)

nu = dagn.nu_obs #frequency grid
ang=3e18/nu
ax.loglog(ang, nu*Lrel_dsc, ls='-.', color='deeppink', label="accretion disk")
ax.loglog(ang, nu*Lrel_wrm, ls='-.', color='indigo', label="warm Comptonising component")
ax.loglog(ang, nu*Lrel_hot, ls='-.', color='midnightblue', label="hot Comptonising component")

ax.loglog(ang, nu*Lnu_rel, color='k', label="total SED")


ax.set_ylim(max(nu*Lnu_nr)*1e-2, max(nu*Lnu_nr)*2)

ax.set_xlabel(r'Wavelength   ($\mathrm{\AA}$)')
ax.set_ylabel(r'$\nu F_{\nu}$   (ergs/s)')
plt.legend()
plt.savefig("example.png", dpi=300)
plt.show()

# XSPEC codes

for running XSPEC fit:

In [None]:
method leven 500 0.01
abund angr
xsect vern
cosmo 70 0 0.73
xset delta 0.01
xset KYRH     1.14106727
xset KYRIN     4.02531052
xset KYRMS     1.45449758
systematic 0
model  phabs(relagn + relxill)
      0.0310315      0.001          0          0     100000      1e+06
          1e+08         -1          1          1      1e+10      1e+10
          161.2         -1     0.0001     0.0001       1000       1000
       -1.18768       0.01        -10        -10          2          2
           0.99         -1          0          0      0.998      0.998
            0.5         -1       0.09       0.09      0.998      0.998
            100         -1         10         10        300        300
       0.114629      0.001       0.01       0.01          1          1
        1.82446       0.01        1.3        1.3          3          3
        2.39367       0.01          2          2          5          5
        4.33208        0.1          1          1        500        500
              6       -0.1          6          6        500        500
             -1         -1         -1         -1          7          7
/
        9.57348          1          6          6         10         10
         0.0364         -1          0          0          1          1
              1         -1          0          0      1e+20      1e+24
              3       -0.1        -10        -10         10         10
              3       -0.1        -10        -10         10         10
/
= p5
= acos(p6)/3.141592*180
             -1         -1       -100       -100         -1         -1
        2.89228        0.1          1          1        400       1000
= p16
= p9
        1.69563       0.01          0          0        4.7        4.7
              1         -1        0.5        0.5         10         10
            300       -0.1          5          5       1000       1000
             -1         -1      -1000          0         10       1000
    0.000140173       0.01          0          0      1e+20      1e+24
= p1
= p2
= p3
= p4
= p5
= p6
= p7
= p8
= p9
= p10
= p11
= p12
= p13
= p14
= p15
= p16
= p17
= p18
= p19
= p20
= p21
= p22
= p23
= p24
= p25
= p26
= p27
= p28
= p29
= p30
= p31
newpar 14 = p23
newpar 20 = p23
bayes off

for plotting models:

In [None]:
method leven 500 0.01
abund angr
xsect vern
cosmo 70 0 0.73
xset delta 0.01
xset KYRH     1.14106727
xset KYRIN     4.02531052
xset KYRMS     1.45449758
systematic 0
model  (relagn + relxill)
          1e+08         -1          1          1      1e+10      1e+10
          161.2         -1     0.0001     0.0001       1000       1000
       -1.18768       0.01        -10        -10          2          2
           0.99         -1          0          0      0.998      0.998
            0.5         -1       0.09       0.09      0.998      0.998
            100         -1         10         10        300        300
       0.114629      0.001       0.01       0.01          1          1
        1.82446       0.01        1.3        1.3          3          3
        2.39367       0.01          2          2          5          5
        4.33208        0.1          1          1        500        500
              6       -0.1          6          6        500        500
             -1         -1         -1         -1          7          7
/
        9.57348          1          6          6         10         10
         0.0364         -1          0          0          1          1
              1         -1          0          0      1e+20      1e+24
              3       -0.1        -10        -10         10         10
              3       -0.1        -10        -10         10         10
/
= p4
= acos(p5)/3.141592*180
             -1         -1       -100       -100         -1         -1
        2.89228        0.1          1          1        400       1000
= p15
= p8
        1.69563       0.01          0          0        4.7        4.7
              1         -1        0.5        0.5         10         10
            300       -0.1          5          5       1000       1000
             -1         -1      -1000          0         10       1000
    0.000140173       0.01          0          0      1e+20      1e+24
= p1
= p2
= p3
= p4
= p5
= p6
= p7
= p8
= p9
= p10
= p11
= p12
= p13
= p14
= p15
= p16
= p17
= p18
= p19
= p20
= p21
= p22
= p23
= p24
= p25
= p26
= p27
= p28
= p29
= p30
newpar 13 = p22
newpar 19 = p22
bayes off