In [None]:
import emcee
import numpy as np
import pyimfit
import pathlib
import cProfile
import time
import pickle
import matplotlib.pyplot as plt
from astropy.io import fits
import tqdm
from multiprocessing import Pool
import corner

# Input data: the AGN image and the PSF files from ../psfConstruction.
imageFile = "agn.fits"
imageAGN = fits.getdata(imageFile)

epsf = fits.getdata("../psfConstruction/epsf2.fits")
epsf_star = fits.getdata("../psfConstruction/star0.fits")

# Default colour map for every image displayed below.
plt.rcParams['image.cmap'] = 'Blues'
epsf_star = fits.getdata("../psfConstruction/star0.fits")

In [None]:
def getFits(pkl_path='fitResults/J1215+1344_fit_.pkl'):
    """Rebuild pyimfit fitters for the saved n=1 and n=4 model sets.

    Loads a pickle written by an earlier fitting run. Creates one
    ``pyimfit.Imfit`` per stored model configuration and loads ``imageAGN``
    into each one with the same PSF-oversampling and noise settings.

    Parameters
    ----------
    pkl_path : str or path-like, optional
        Pickle with keys 'fitConfig_n1', 'fitConfig_n4', 'bestfit_n1' and
        'bestfit_n4'. Defaults to the J1215+1344 results. Only load pickles
        you trust: unpickling can execute arbitrary code.

    Returns
    -------
    fitters_n1, fitters_n4 : list of pyimfit.Imfit
        One fitter per model configuration, in stored order.
    bestfit_n1, bestfit_n4
        The stored best-fit entries, returned unchanged (indexed per model
        further down the notebook).
    """
    with open(pkl_path, 'rb') as file:
        d = pickle.load(file)

    # A single oversampled-PSF region shared by every fitter.
    # NOTE(review): epsf serves both as the normal PSF and as the 4x
    # oversampled PSF for region (0,100,0,100) -- confirm this is intended.
    psfOsamp = pyimfit.MakePsfOversampler(epsf, 4, (0, 100, 0, 100))
    osampleList = [psfOsamp]

    fitters_n1, fitters_n4 = [], []
    model_sets = ((fitters_n1, d['fitConfig_n1']), (fitters_n4, d['fitConfig_n4']))
    for fitters, models in model_sets:
        for model in tqdm.tqdm(models, desc="Fitting Models"):
            imfit_fitter = pyimfit.Imfit(model, psf=epsf)
            # gain / read noise / original sky are hard-coded for this image.
            imfit_fitter.loadData(imageAGN, psf_oversampling_list=osampleList,
                                  gain=9.942e-1, read_noise=0.22, original_sky=15.683)
            fitters.append(imfit_fitter)
    return fitters_n1, fitters_n4, d['bestfit_n1'], d['bestfit_n4']

fitters_n1, fitters_n4, bestfits_n1, bestfits_n4 = getFits()

In [None]:
# Scratch check: build one fitter (n=1 set, model 1) by hand, outside getFits().
with open('fitResults/J1215+1344_fit_.pkl', 'rb') as file:
    d = pickle.load(file)

# Model configurations used to create the fitters.
models_n1, models_n4 = d['fitConfig_n1'], d['fitConfig_n4']

psfOsamp = pyimfit.MakePsfOversampler(epsf, 4, (0, 100, 0, 100))
osampleList = [psfOsamp]

imfit_fitter = pyimfit.Imfit(models_n1[1], psf=epsf)
imfit_fitter.loadData(
    imageAGN,
    psf_oversampling_list=osampleList,
    gain=9.942e-1,
    read_noise=0.22,
    original_sky=15.683,
)

In [None]:
# Parameter limits of model m, without the Sersic-index entries (which are not
# sampled), and with None replaced by a broad (0, 100000) box.
m = 1
names = fitters_n1[m].numberedParameterNames
# Match the exact base name "n" (e.g. "n_2"). A substring test on "n_" would
# also catch any other parameter whose base name happens to end in "n".
rmind = [i for i, name in enumerate(names) if name.rsplit("_", 1)[0] == "n"]
parameterLimits = [
    (0, 100000) if limits is None else limits
    for i, limits in enumerate(fitters_n1[m].getParameterLimits())
    if i not in rmind
]
parameterLimits

In [None]:
def get_rm_inds(fitter, base_name="n"):
    """Return the indices of parameters to hold fixed (exclude from sampling).

    Parameters
    ----------
    fitter : object with a ``numberedParameterNames`` sequence
        e.g. a pyimfit.Imfit; names look like "<base>_<component>" ("n_2").
    base_name : str, optional
        Parameter base name to select. Default "n" (the Sersic index).

    Returns
    -------
    list of int
        Ascending positions in ``numberedParameterNames`` whose base name
        (the text before the last "_") equals ``base_name``.
    """
    # Exact comparison of the base name: a plain substring test on "n_" also
    # matched any name whose base ends in "n".
    return [i for i, name in enumerate(fitter.numberedParameterNames)
            if name.rsplit("_", 1)[0] == base_name]

# Indices of the fixed (unsampled) parameters for each n=1 fitter.
rm_inds = list(map(get_rm_inds, fitters_n1))
rm_inds

# Check the prior

In [None]:

def lnPrior_func(params, imfitter, rmind, verbose=False):
    """Flat (top-hat) log-prior for a sampled parameter vector.

    Parameters
    ----------
    params : sequence of float
        Sampled parameters: the full model vector with the ``rmind``
        entries removed.
    imfitter : object with ``getParameterLimits()``
        e.g. a pyimfit.Imfit; its limits list is indexed like the full vector.
    rmind : sequence of int
        Full-vector indices excluded from sampling. Their limits are dropped
        so the remaining limits line up with ``params``.
    verbose : bool, optional
        If True, print the violated limits and the offending value. The
        default is False, so the prior stays silent when the sampler calls it
        repeatedly.

    Returns
    -------
    float
        0.0 if every parameter lies within its (inclusive) limits, else -inf.
    """
    rmset = set(rmind)
    # None entries get a broad default box.
    # NOTE(review): confirm what None means in pyimfit's getParameterLimits().
    parameterLimits = [
        (0, 100000) if limits is None else limits
        for indx, limits in enumerate(imfitter.getParameterLimits())
        if indx not in rmset
    ]
    for i, value in enumerate(params):
        lower, upper = parameterLimits[i][0], parameterLimits[i][1]
        if value < lower or value > upper:
            if verbose:
                print(parameterLimits[i], value)
            return -np.inf
    return 0.0

# Start 50 walkers in a small Gaussian ball around model m's best fit (fixed
# entries removed) and evaluate the prior for the first 20 of them.
m = 1
p_bestfit = np.delete(bestfits_n1[m], rm_inds[m])
ndims, nwalkers = len(p_bestfit), 50
initial_pos = [p_bestfit + 0.001 * np.random.randn(ndims) for _ in range(nwalkers)]

for i in range(20):
    prior_value = lnPrior_func(initial_pos[i], fitters_n1[m], rm_inds[m])
    print(i, prior_value)


# Check the posterior

In [None]:
def lnPosterior_func_chi(params, imfitter, p_bestfit, rmind, fixedValues=1):
    """Log-posterior using an explicit chi-square against ``imageAGN``.

    Alternative to ``lnPosterior_pf``. It renders the model image directly
    and assumes a per-pixel uncertainty of 1% of the data value.

    Parameters
    ----------
    params : array-like of float
        Sampled parameters: the full model vector with the ``rmind``
        entries removed.
    imfitter : pyimfit.Imfit
        Fitter used to render the model image.
    p_bestfit : array-like
        Not used; kept so existing calls keep working. (Evaluating it instead
        of ``params`` made the likelihood independent of the sample.)
    rmind : sequence of int
        Indices, in the full parameter vector, of the parameters held fixed.
    fixedValues : float or sequence of float, optional
        Value(s) written at ``rmind``. Default 1 (the previous hard-coded value).

    Returns
    -------
    float
        log-prior + log-likelihood, or -inf when the prior rejects ``params``.
    """
    lnPrior = lnPrior_func(params, imfitter, rmind)
    if not np.isfinite(lnPrior):
        return -np.inf

    # Re-expand to the full parameter vector. np.insert would treat rmind as
    # positions in the *reduced* array, misplacing every fixed value after the
    # first, so scatter through a mask instead.
    reduced = np.asarray(params, dtype=float)
    rm = np.asarray(rmind, dtype=int)
    fullParams = np.empty(reduced.size + rm.size)
    keep = np.ones(fullParams.size, dtype=bool)
    keep[rm] = False
    fullParams[rm] = fixedValues
    fullParams[keep] = reduced

    newIm = imfitter.getModelImage(newParameters=fullParams)
    # NOTE(review): sigma = 1% of the data; zero-valued pixels make this inf/nan.
    chisquared = np.sum(((newIm - imageAGN) ** 2) / (imageAGN * 0.01) ** 2)
    lnLikelihood = -0.5 * chisquared
    return lnPrior + lnLikelihood

def lnPosterior_pf(params, imfitter, lnPrior_func, rmInd, fixedValues=1):
    """Log-posterior for emcee, using pyimfit's fit statistic as the likelihood.

    Parameters
    ----------
    params : array-like of float
        Sampled parameters: the full model vector with the ``rmInd``
        entries removed.
    imfitter : pyimfit.Imfit
        Fitter whose ``computeFitStatistic`` is evaluated on the full vector.
    lnPrior_func : callable
        ``lnPrior_func(params, imfitter, rmInd)``, returning a log-prior.
    rmInd : sequence of int
        Indices, in the full parameter vector, of the parameters held fixed.
    fixedValues : float or sequence of float, optional
        Value(s) written at ``rmInd``. The default 1 is the previous
        hard-coded value; pass the real fixed values for other model sets.

    Returns
    -------
    float
        log-prior - 0.5 * fit statistic. Returns -inf when the prior rejects
        ``params``; the fit statistic is then not computed.
    """
    lnPrior = lnPrior_func(params, imfitter, rmInd)
    if not np.isfinite(lnPrior):
        return -np.inf

    # Re-expand to the full parameter vector. np.insert interprets indices as
    # positions in the *reduced* input array, so with several fixed
    # parameters it shifted every one after the first. Scatter through a
    # mask so each fixed value lands exactly at its full-vector index.
    reduced = np.asarray(params, dtype=float)
    rm = np.asarray(rmInd, dtype=int)
    fullParams = np.empty(reduced.size + rm.size)
    keep = np.ones(fullParams.size, dtype=bool)
    keep[rm] = False
    fullParams[rm] = fixedValues
    fullParams[keep] = reduced

    # Assumes computeFitStatistic is chi-square-like, so -0.5*stat is a
    # log-likelihood up to a constant -- confirm for the fit statistic in use.
    lnLikelihood = -0.5 * imfitter.computeFitStatistic(fullParams)
    return lnPrior + lnLikelihood

# Redraw the starting ball for model m and evaluate the posterior at the
# first ten walkers.
m = 1
p_bestfit = np.delete(bestfits_n1[m], rm_inds[m])
ndims, nwalkers = len(p_bestfit), 50
initial_pos = [p_bestfit + 0.001 * np.random.randn(ndims) for _ in range(nwalkers)]

[lnPosterior_pf(pos, fitters_n1[m], lnPrior_func, rm_inds[m]) for pos in initial_pos[:10]]


In [None]:
# Best-fit value of every parameter (fixed ones included) for n=1 model 1.
# A plain loop avoids displaying a list of Nones, and zip removes the
# hard-coded parameter count.
for name, value in zip(fitters_n1[1].numberedParameterNames, bestfits_n1[1]):
    print(name, value)

In [None]:
def main():
    """Profiling target: evaluate the posterior once per starting walker."""
    for i in range(nwalkers):
        lnPosterior_pf(initial_pos[i], fitters_n1[m], lnPrior_func, rm_inds[m])


cProfile.run("main()", sort="cumulative")

In [None]:
# Shift every walker by +0.5 in every parameter, to profile the posterior
# away from the best fit (some shifted positions may fall outside the prior).
pos1 = np.array(initial_pos) + 0.5


def main1():
    """Profiling target: evaluate the posterior at each shifted position."""
    for i in range(nwalkers):
        lnPosterior_pf(pos1[i], fitters_n1[m], lnPrior_func, rm_inds[m])


cProfile.run("main1()", sort="cumulative")

In [None]:
def run_emcee(p_bestfit, fitter, rmInd, numsteps, nwalkers=50, scatter=1e-3):
    """Sample one model's posterior with emcee's EnsembleSampler.

    Parameters
    ----------
    p_bestfit : array-like of float
        Full best-fit parameter vector, including the fixed entries.
    fitter : pyimfit.Imfit
        Fitter passed through to ``lnPosterior_pf``.
    rmInd : sequence of int
        Full-vector indices of the parameters held fixed (not sampled).
    numsteps : int
        Number of MCMC steps per walker.
    nwalkers : int, optional
        Number of walkers (default 50, as before).
    scatter : float, optional
        Std. dev. of the Gaussian ball around the best fit used for the
        starting positions (default 1e-3, as before).

    Returns
    -------
    emcee.EnsembleSampler
        The sampler after ``run_mcmc``; the chain is available from it.
    """
    p_start = np.delete(p_bestfit, rmInd)
    ndims = len(p_start)
    # One row per walker: best fit plus small Gaussian jitter.
    initial_pos = p_start + scatter * np.random.randn(nwalkers, ndims)

    sampler = emcee.EnsembleSampler(
        nwalkers, ndims, lnPosterior_pf, args=(fitter, lnPrior_func, rmInd)
    )
    sampler.run_mcmc(initial_pos, numsteps, progress=True)
    return sampler

In [None]:
# Quick 2-step smoke run for n=1 model 1.
sampler = run_emcee(
    p_bestfit=bestfits_n1[1],
    fitter=fitters_n1[1],
    rmInd=rm_inds[1],
    numsteps=2,
)


In [None]:
# Public attributes/methods of the sampler (dunder and private names hidden).
[name for name in dir(sampler) if not name.startswith('_')]

In [None]:
# Axis labels for the 14 sampled parameters of n=1 model 1 (the Sersic-index
# entries removed via rm_inds are presumably the missing n_2 / n_4).
l = [
    'X0_1', 'Y0_1', 'I_tot_1',
    'PA_2', 'ell_bulge_2', 'I_e_2', 'r_e_2',
    'X0_2', 'Y0_2', 'I_tot_3',
    'PA_4', 'ell_bulge_4', 'I_e_4', 'r_e_4',
]

In [None]:
def PlotAllWalkers(sample_chain, yAxisLabel, figtitle):
    """Trace plot of every walker, one panel per labelled parameter.

    Parameters
    ----------
    sample_chain : ndarray
        Chain indexed as [walker, step, parameter]. This is the layout of
        emcee's ``sampler.chain``; ``get_chain()`` uses [step, walker, param].
    yAxisLabel : sequence of str
        One label per parameter. Only the first ``len(yAxisLabel)``
        parameters are plotted.
    figtitle : str
        Title placed above the first panel.
    """
    nPanels = len(yAxisLabel)
    # squeeze=False keeps `axes` 2-D, so a single label no longer yields a
    # bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(nPanels, 1, figsize=(8, nPanels * 3), squeeze=False)
    nWalkers = sample_chain.shape[0]
    for j, (ax, label) in enumerate(zip(axes[:, 0], yAxisLabel)):
        for i in range(nWalkers):
            ax.plot(sample_chain[i, :, j], color='0.5')
        ax.set_xlabel('Step number')
        ax.set_ylabel(label)
    axes[0, 0].set_title(figtitle)

    fig.tight_layout()
 

# Trace plots for the chain that was just sampled.
PlotAllWalkers(sample_chain=sampler.chain, yAxisLabel=l, figtitle="")


In [None]:
# Reload a previously saved chain and show its walker traces.
with open("chain_n1_1_500_pf.pkl", "rb") as file:
    d = pickle.load(file)
saved_chain = d['chain']
PlotAllWalkers(saved_chain, l, "")

In [None]:
# Snapshot the sampler's results as plain arrays so they can be pickled.
# NOTE(review): in emcee 3, chain/flatchain/lnprobability/flatlnprobability
# are deprecated aliases of get_chain()/get_log_prob() -- confirm against the
# installed emcee version.
atts = ['acceptance_fraction', 'chain', 'flatchain', 'flatlnprobability', 'lnprobability']
attributes_and_methods = {attr: getattr(sampler, attr) for attr in atts}

attributes_and_methods

In [None]:
# Name the file after the number of steps actually stored, so a short test
# run cannot overwrite the saved 500-step chain. `chain` is laid out as
# [walker, step, parameter], so axis 1 is the step count.
n_steps = attributes_and_methods['chain'].shape[1]
chain_file = f"chain_n1_1_{n_steps}_pf.pkl"
with open(chain_file, 'wb') as file:
    pickle.dump(attributes_and_methods, file)

In [None]:
# Read the saved 500-step chain back from disk.
c = pickle.loads(pathlib.Path("chain_n1_1_500_pf.pkl").read_bytes())

In [None]:
# Walker traces of the reloaded chain.
PlotAllWalkers(sample_chain=c['chain'], yAxisLabel=l, figtitle="")