In [51]:
from astropy.coordinates import SkyCoord
import numpy as np
from IntegralQuery import SearchQuery, IntegralQuery, Filter, Range
from IntegralPointingClustering import ClusteredQuery
import astropy.io.fits as fits
from astropy.table import Table
from datetime import datetime
import matplotlib.pyplot as plt
import math
from numba import njit
from pyspi.utils.function_utils import find_response_version
from pyspi.utils.response.spi_response_data import ResponseDataRMF
from pyspi.utils.response.spi_response import ResponseRMFGenerator
from pyspi.utils.response.spi_drm import SPIDRM
from pyspi.utils.livedets import get_live_dets
from astromodels import Powerlaw, Log_uniform_prior, Uniform_prior, PointSource, SpectralComponent, Model
from chainconsumer import ChainConsumer
import pymultinest
import os
import astropy.time as at
from scipy.stats import poisson
import pickle

In [52]:
rsp_bases = tuple([ResponseDataRMF.from_version(i) for i in range(5)])

Using the irfs that are valid between Start and 03/07/06 06:00:00 (YY/MM/DD HH:MM:SS)
Using the irfs that are valid between 03/07/06 06:00:00 and 04/07/17 08:20:06 (YY/MM/DD HH:MM:SS)
Using the irfs that are valid between 04/07/17 08:20:06 and 09/02/19 09:59:57 (YY/MM/DD HH:MM:SS)
Using the irfs that are valid between 09/02/19 09:59:57 and 10/05/27 12:45:00 (YY/MM/DD HH:MM:SS)
Using the irfs that are valid between 10/05/27 12:45:00 and present (YY/MM/DD HH:MM:SS)


In [53]:
@njit
def b_maxL_2(m, t, C):
    first = C[0]+C[1]-(m[0]+m[1])*(t[0]+t[1])
    root = (C[0]+C[1]+(m[0]-m[1])*(t[0]+t[1]))**2-4*C[0]*(m[0]-m[1])*(t[0]+t[1])
    res = (first+np.sqrt(root))/(2*(t[0]+t[1]))
    if res < 0:
        return 0
    return res

@njit #################### needs more testing, can do without error?
def b_maxL_3(m, t, C):
    mt = m[0] + m[1] + m[2]
    tt = t[0] + t[1] + t[2]
    Ct = C[0] + C[1] + C[2]
    a = -tt
    b = -tt*mt + Ct
    c = Ct*mt - C[0]*m[0] - C[1]*m[1] - C[2]*m[2] -tt*(m[0]*m[1] + m[1]*m[2] + m[2]*m[0])
    d = C[0]*m[1]*m[2] + C[1]*m[2]*m[0] + C[2]*m[0]*m[1] - tt*m[0]*m[1]*m[2]
    D0 = b**2 - 3*a*c
    D1 = 2*b**3 - 9*a*b*c + 27*(a**2)*d
        
    if D0 == 0. and D1 == 0.:
        return -b/(3*a)
    
    C0 = ((D1 + np.sqrt(D1**2 - 4*D0**3 + 0j)) / 2)**(1/3)
    
    if C0 == 0:
        C0 = ((D1 - np.sqrt(D1**2 - 4*D0**3 + 0j)) / 2)**(1/3)
        
    x0 = -1/(3*a) * (b + C0 + D0/C0)
    
    if x0.real < 0:
        return 0.
    
    return x0.real

In [54]:
@njit
def logLcore(
    spec_binned,
    pointings,
    dets,
    resp_mats,
    num_sources,
    t_elapsed,
    counts
):
    logL=0
    for p_i in range(len(pointings)):
        for d_i in range(len(dets[p_i])):
            n_p = len(pointings[p_i])
            m = np.zeros((n_p, len(resp_mats[p_i][0][0][0,0,:])))
            
            t_b = np.zeros(n_p)
            for t_i in range(n_p):
                t_b[t_i] = t_elapsed[p_i][t_i][d_i]
            C_b = np.zeros(n_p)
            
            for s_i in range(num_sources):
                for m_i in range(n_p):
                    m[m_i,:] += np.dot(spec_binned[s_i,:], resp_mats[p_i][s_i][m_i][d_i])
            for e_i in range(len(m[0])):
                m_b = m[:,e_i]
                for C_i in range(n_p):
                    C_b[C_i] = counts[p_i][C_i][d_i, e_i]
                    
                if n_p == 2:
                    b = b_maxL_2(m_b, t_b, C_b)
                elif n_p == 3:
                    b = b_maxL_3(m_b, t_b, C_b)
                else:
                    print()
                    print("b_maxL is not defined")
                    print()
                    return 0.
                for m_i in range(n_p):
                    logL += (counts[p_i][m_i][d_i, e_i]*math.log(t_elapsed[p_i][m_i][d_i]*(m[m_i,e_i]+b))
                            -t_elapsed[p_i][0][d_i]*(m[m_i,e_i]+b))
    return logL

In [55]:
def extract_pointing_info(path, p_id):
    num_dets = 19
    with fits.open(f"{path}/pointing.fits") as file:
        t = Table.read(file[1])
        index = np.argwhere(t["PTID_ISOC"]==p_id[:8])
        
        if len(index) < 1:
            raise Exception(f"{p_id} not found")

        pointing_info = t[index[-1][0]]
        
        t1 = at.Time(f'{pointing_info["TSTART"]+2451544.5}', format='jd').datetime
        time_start = datetime.strftime(t1,'%y%m%d %H%M%S')
            
    with fits.open(f"{path}/dead_time.fits") as file:
        t = Table.read(file[1])
        
        time_elapsed = np.zeros(num_dets)
        
        for i in range(num_dets):
            for j in index:
                time_elapsed[i] += t["LIVETIME"][j[0]*85 + i]
        
    with fits.open(f"{path}/energy_boundaries.fits") as file:
        t = Table.read(file[1])
        
        energy_bins = np.append(t["E_MIN"], t["E_MAX"][-1])
    
    with fits.open(f"{path}/evts_det_spec.fits") as file:
        t = Table.read(file[1])
        
        counts = np.zeros((num_dets, len(energy_bins)-1))
        for i in range(num_dets):
            for j in index:
                counts[i, : ] += t["COUNTS"][j[0]*85 + i]

    return time_start, time_elapsed, energy_bins, counts

In [56]:
def generate_resp_mat(
    rmfs,
    len_dets,
    len_ebs,
    len_emod,
    ra,
    dec,
):
    sds = np.empty(0)
    for d in range(len_dets):
        sd = SPIDRM(rmfs[d], ra, dec)
        sds = np.append(sds, sd.matrix.T)
    return sds.reshape((len_dets, len_emod-1, len_ebs-1))

In [57]:
def calc_mahalanobis_dist(summary, cov, true_vals):
    fit_val = np.array([i[1] for i in summary.values()])
    fit_cov = cov[1]
    rel_distance = []
    
    for i in range(len(true_vals)):
        dif = fit_val - true_vals[i]
        
        rel_distance.append(np.sqrt(
            np.linalg.multi_dot([dif, np.linalg.inv(fit_cov), dif])
        ))
        
    return np.array(rel_distance)

In [58]:
class MultinestClusterFit:
    def __init__(
        self,
        pointings,
        source_model,
        energy_range,
        emod,
        binning_func,
        true_values=None,
        folder=None,
    ):
        self._pointings = pointings
        self._source_model = source_model
        self._binning_func = binning_func
        self._energy_range = energy_range
        self._emod = emod
        
        self._true_values = true_values
        self.set_folder(folder)
        
        self._prepare_fit_data()
        
        self._find_updatable_sources()
        
        self._initialize_resp_mats()
        
        self._run_multinest()
                
        self._extract_parameter_names_simple()
        self._parameter_names.extend(["$z$"])
            
        self._cc = ChainConsumer()
        self._chain = np.loadtxt('./chains/1-post_equal_weights.dat')
        self._cc.add_chain(self._chain, parameters=self._parameter_names, name='fit')
        
    def _run_multinest(self):
        num_sources = len(self._source_model.sources)
        
        def logLba_mult(trial_values, ndim=None, params=None):
            spec_binned = np.zeros((num_sources, len(self._emod)-1))
            for i, parameter in enumerate(self._source_model.free_parameters.values()):
                parameter.value = trial_values[i]
            for i, source in enumerate(self._source_model.sources.values()):
                spec = source(self._emod)
                spec_binned[i,:] = (self._emod[1:]-self._emod[:-1])*(spec[:-1]+spec[1:])/2
            if 1 in self._updatable_sources:
                self._update_resp_mats()
            return logLcore(
                spec_binned,
                self._pointings,
                self._dets,
                self._resp_mats,
                num_sources,
                self._t_elapsed,
                self._counts,
            )
        
        def prior(params, ndim=None, nparams=None):
            for i, parameter in enumerate(self._source_model.free_parameters.values()):
                try:
                    params[i] = parameter.prior.from_unit_cube(params[i])

                except AttributeError:
                    raise RuntimeError(
                        "The prior you are trying to use for parameter %s is "
                        "not compatible with sampling from a unitcube"
                        % parameter.path
                    )

        num_params = len(self._source_model.free_parameters)
        
        # ###
        # trial_values = [4.5e-3, -2.08, 7e-3, -2.5]
        # spec_binned = np.zeros((num_sources, len(self._emod)-1))
        # for i, parameter in enumerate(self._source_model.free_parameters.values()):
        #     parameter.value = trial_values[i]
        # for i, source in enumerate(self._source_model.sources.values()):
        #     spec = source(self._emod)
        #     spec_binned[i,:] = (self._emod[1:]-self._emod[:-1])*(spec[:-1]+spec[1:])/2
        # print(logLcore(
        #     spec_binned,
        #     self._pointings,
        #     self._dets,
        #     self._resp_mats,
        #     num_sources,
        #     self._t_elapsed,
        #     self._counts,
        # ))
        # return None
        # ###

        if not os.path.exists("./chains"):
            os.mkdir("chains")
        sampler = pymultinest.run(
            logLba_mult, prior, num_params, num_params, n_live_points=800, resume=False, verbose=True
        )
    
    def _prepare_fit_data(self):
        ebs = []
        counts = []
        dets = []
        t_elapsed = []
        t_start = []
            
        for combination in self._pointings:
            c_time_start, c_time_elapsed = [], []
            for p_i, pointing in enumerate(combination):
                time_start, time_elapsed, energy_bins, counts_f = extract_pointing_info(pointing[1], pointing[0])
                c_time_start.append(time_start)
                dets_temp = get_live_dets(time=time_start, event_types=["single"])
                c_time_elapsed.append(time_elapsed[dets_temp])
                
                if p_i == 0:
                    dets_0 = dets_temp
                    energy_bins_0 = energy_bins
                    c_counts_f = counts_f[dets_0]
                else:
                    c_counts_f = np.append(c_counts_f, counts_f[dets_0], axis=0)
                    assert np.array_equal(dets_0, dets_temp), f"Active detectors are not the same for {combination[0][0]} and {combination[p_i][0]}"
                    assert np.array_equal(energy_bins_0, energy_bins), f"Energy bins are not the same for {combination[0][0]} and {combination[p_i][0]}"
                
            eb, c = self._binning_func(
                energy_bins_0,
                c_counts_f,
                self._energy_range
            )
            nd = len(dets_0)
            counts.append(tuple([c[i*nd : (i+1)*nd] for i in range(len(combination))]))
            ebs.append(eb)
            
            t_start.append(tuple(c_time_start))
            dets.append(dets_0)
            t_elapsed.append(tuple(c_time_elapsed))
                            
                
            # time_start1, time_elapsed1, energy_bins1, counts_f1 = extract_pointing_info(pair[0], pair[1])
            # time_start2, time_elapsed2, energy_bins2, counts_f2 = extract_pointing_info(pair[0], pair[2])
            # t_start.append((time_start1, time_start2))
            
            # dets1 = get_live_dets(time=time_start1, event_types=["single"])
            # dets2 = get_live_dets(time=time_start2, event_types=["single"])
            # assert np.array_equal(dets1, dets2), f"Active detectors are not the same for {pair[1]} and {pair[2]}"
            # dets.append(dets1)
            
            # t_elapsed.append((time_elapsed1[dets1], time_elapsed2[dets1]))
            
            # assert np.array_equal(energy_bins1, energy_bins2), f"Energy bins are not the same for {pair[1]} and {pair[2]}"
            
            # eb, c = binning_func(
            #     energy_bins1,
            #     np.append(counts_f1[dets1], counts_f2[dets1], axis=0),
            #     energy_range)
            # counts.append((c[:len(dets1)], c[len(dets1):]))
            # ebs.append(eb)
                
        self._ebs = tuple(ebs) 
        self._counts = tuple(counts)
        self._dets = tuple(dets)
        self._t_elapsed = tuple(t_elapsed)
        self._t_start = tuple(t_start)
    
    def _initialize_resp_mats(self):
        # index order: tuple(combination, source, pointing, np_array(dets, e_in, e_out))
        resp_mats = []
        rmfs = []
        
        for count, combination in enumerate(self._pointings):
            # version1 = find_response_version(self._t_start[count][0])
            # for pointing in range(1, len(combination)):
            #     version2 = find_response_version(self._t_start[count][pointing])
            #     assert version1 == version2, f"Response versions are not equal for {combination[0][0]} and {combination[pointing][0]}"
            # rsp_base = rsp_bases[version1]
            
            source_resp_mats = []
            
            dets = self._dets[count]
            ebs = self._ebs[count]
            
            for source in self._source_model.sources.values():
                combination_resp_mats = []
                combination_rmfs = []
                
                for pointing in range(len(combination)):
                    time = self._t_start[count][pointing]
                    version = find_response_version(time)
                    rsp_base = rsp_bases[version]
                    
                    pointing_rmfs = []
                    for d in dets:
                        pointing_rmfs.append(ResponseRMFGenerator.from_time(time, d, ebs, self._emod, rsp_base))
                    pointing_rmfs = tuple(pointing_rmfs)
                                        
                    combination_resp_mats.append(
                        generate_resp_mat(
                            pointing_rmfs,
                            len(dets),
                            len(ebs),
                            len(self._emod),
                            source.position.get_ra(),
                            source.position.get_dec(),
                        )
                    )
                    combination_rmfs.append(pointing_rmfs)
                    
                source_resp_mats.append(tuple(combination_resp_mats))
                    
            resp_mats.append(tuple(source_resp_mats))
            rmfs.append(tuple(combination_rmfs))
            
        self._resp_mats = tuple(resp_mats)
        if 1 in self._updatable_sources:
            self._updatable_rmfs = tuple(rmfs)

    def _update_resp_mats(self):
        for count, combination in enumerate(self._pointings):
            for source_num, source in enumerate(self._source_model.sources.values()):
                if self._updatable_sources[source_num] == 1:
                    for pointing in range(len(combination)):
                        self._resp_mats[count][source_num][pointing][:,:,:] = generate_resp_mat(
                            self._updatable_rmfs[count][pointing],
                            len(self._dets[count]),
                            len(self._ebs[count]),
                            len(self._emod),
                            source.position.get_ra(),
                            source.position.get_dec(),
                        )

    def _find_updatable_sources(self):
        keywords = ["position"]
        self._updatable_sources = np.zeros(len(self._source_model.sources), np.int8)
        for s_i, source in enumerate(self._source_model.sources.values()):
            for parameter in source.free_parameters.values():
                first_pos = parameter.path.find(".")
                second_pos = parameter.path.find(".", first_pos+1)
                if parameter.path[first_pos+1 : second_pos] in keywords:
                    self._updatable_sources[s_i] = 1

    def parameter_fit_distribution(self):
        assert not self._folder is None, "folder is not set"
        
        fig = self._cc.plotter.plot(
            parameters=self._parameter_names[:-1],
            # truth={'Crab K':true_values_main[0,0], 'Crab index':true_values_main[0,1]},
            figsize=1.5
        )
        
        plt.savefig(f"{self._folder}/parameter_fit_distributions.pdf")
        plt.close()
        
    def text_summaries(
        self,
        reference_values=True,
        pointing_combinations=True,
        parameter_fit_constraints=True
    ):
        assert not self._folder is None, "folder is not set"
        
        
        if reference_values:
            assert not self._true_values is None, "true_values not set"
            summary = self._cc.analysis.get_summary(parameters=self._true_values[0])
            cov = self._cc.analysis.get_covariance(parameters=self._true_values[0])
            rel_distances = calc_mahalanobis_dist(summary, cov, self._true_values[1])
            
            with open(f"{self._folder}/reference_values", "w") as f:
                f.write(f"{' : '.join(self._true_values[0])} : Rel. Dist.\n")
                for i in range(self._true_values[1].shape[0]):
                    f.write(f"{' : '.join([f'{j:.3}' for j in self._true_values[1][i,:]])} : {rel_distances[i]:.3}\n")
                
        if pointing_combinations:
            with open(f"{self._folder}/pointing_combinations", "w") as f:
                for combination in self._pointings:
                    f.write(f'{"  ".join(i[0] for i in combination)}\n')
        
        if parameter_fit_constraints:
            summary = self._cc.analysis.get_summary(parameters=self._parameter_names[:-1])
            with open(f"{self._folder}/parameter_fit_constraints", "w") as f:
                for param in self._parameter_names[:-1]:
                    f.write(f"{param}:\n")
                    try:
                        f.write(f"{summary[param][0]:.5}  {summary[param][1]:.5}  {summary[param][2]:.5}\n")
                    except:
                        f.write(f"None  {summary[param][1]:.5}  None\n")
    
    def ppc( # allow updatable sources!
        self,
        count_energy_plots=True,
        qq_plots=True
    ):
        assert self._folder is not None, "folder is not set"
        
        s, b = self._calc_rates()
        
        for c_i, combination in enumerate(self._pointings):
            for p_i in range(len(combination)):
            
                if count_energy_plots:
                    self._count_energy_plot(
                        b[c_i][p_i],
                        s[c_i][p_i],
                        self._ebs[c_i],
                        self._counts[c_i][p_i],
                        self._dets[c_i],
                        combination[p_i][0]
                    )
                if qq_plots:
                    self._qq_plot(
                        b[c_i][p_i],
                        s[c_i][p_i],
                        self._counts[c_i][p_i],
                        self._dets[c_i],
                        combination[p_i][0]
                    )
            
    def _calc_rates(self):
        source_rate = []
        background_rate = []
        for c_i, combination in enumerate(self._pointings):
            source_rate.append(np.zeros((len(combination), len(self._dets[c_i]), len(self._ebs[c_i])-1, len(self._chain))))
            background_rate.append(np.zeros((len(self._dets[c_i]), len(self._ebs[c_i])-1, len(self._chain))))
        
        num_sources = len(self._source_model.sources)
        
        for p_i, params in enumerate(self._chain):
            spec_binned = np.zeros((num_sources, len(self._emod)-1))
            for fp_i, parameter in enumerate(self._source_model.free_parameters.values()):
                parameter.value = params[fp_i]
            for s_i, source in enumerate(self._source_model.sources.values()):
                spec = source(self._emod)
                spec_binned[s_i,:] = (self._emod[1:]-self._emod[:-1])*(spec[:-1]+spec[1:])/2
            if 1 in self._updatable_sources:
                self._update_resp_mats()
            
            for c_i, combination in enumerate(self._pointings):
                for d_i in range(len(self._dets[c_i])):
                    for s_i in range(num_sources):
                        for m_i in range(len(combination)):
                            source_rate[c_i][m_i,d_i,:,p_i] += np.dot(spec_binned[s_i,:], self._resp_mats[c_i][s_i][m_i][d_i])
                    for e_i in range(len(self._ebs[c_i])-1):
                        s_b = np.array([source_rate[c_i][i,d_i,e_i,p_i] for i in range(len(combination))])
                        t_b = np.array([self._t_elapsed[c_i][i][d_i] for i in range(len(combination))])
                        C_b = np.array([self._counts[c_i][i][d_i, e_i] for i in range(len(combination))])
                        if len(combination) == 2:
                            background_rate[c_i][d_i,e_i,p_i] = b_maxL_2(s_b, t_b, C_b)
                        elif len(combination) == 3:
                            background_rate[c_i][d_i,e_i,p_i] = b_maxL_3(s_b, t_b, C_b)
                        
        source_rates = []
        background_rates = []
        for c_i, combination in enumerate(self._pointings):
            c_source_rates = []
            c_background_rates = []
            
            for p_i in range(len(combination)):
                c_source_rates.append(
                    np.average(source_rate[c_i][p_i], axis=2) * self._t_elapsed[c_i][p_i][:,np.newaxis]
                )
                c_background_rates.append(
                    np.average(background_rate[c_i], axis=2) * self._t_elapsed[c_i][p_i][:,np.newaxis]
                )
            
            source_rates.append(c_source_rates)
            background_rates.append(c_background_rates)

        return source_rates, background_rates

    def _count_energy_plot(
        self,
        b,
        s,
        eb,
        c,
        dets,
        name
    ):
        fig, axes = plt.subplots(5,4, sharex=True, sharey=True, figsize=(10,10))
        axes = axes.flatten()
        
        predicted = b + s
        predicted_lower = poisson.ppf(0.16, predicted)
        predicted_upper = poisson.ppf(0.84, predicted)
        counts = c
        
        i=0
        for d in range(19):
            axes[d].text(.5,.9,f"Det {d}",horizontalalignment='center',transform=axes[d].transAxes)
            if d in dets:
                line1, = axes[d].step(eb[:-1], predicted[i], c="r")
                line2, = axes[d].step(eb[:-1], counts[i], c="k")
                axes[d].fill_between(eb[:-1], predicted_lower[i], predicted_upper[i], color="r", alpha=0.5)
                if i==0:
                    line1.set_label("Predicted Counts")
                    line2.set_label("Real Counts")
                i += 1
            axes[d].set_yscale("log")
        plt.subplots_adjust(hspace=0, wspace=0)
        plt.subplots_adjust(hspace=0, top=0.96, bottom=0.1)
        lgd = fig.legend(loc='center left', bbox_to_anchor=(0.9, 0.5), fontsize='x-large')
        
        fig.add_subplot(111, frameon=False)
        plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
        plt.xlabel("Detected Energy [keV]")
        plt.ylabel("Cumulative Counts")
        
        fig.savefig(f"{self._folder}/{name}_count_energy.pdf")
        plt.close()
    
    def _qq_plot(
        self,
        b,
        s,
        c,
        dets,
        name
    ):
        fig, axes = plt.subplots(5,4, sharex=True, sharey=True, figsize=(10,10))
        axes = axes.flatten()
        
        p = b + s
        predicted = np.cumsum(p, axis=1)
        predicted_lower = np.cumsum(poisson.ppf(0.16, p), axis=1)
        predicted_upper = np.cumsum(poisson.ppf(0.84, p), axis=1)
        counts = np.cumsum(c, axis=1)
        ma = np.amax(
            np.array([np.amax(counts, axis=1), np.amax(predicted, axis=1)]),
            axis=0
        )
        
        i=0
        for d in range(19):
            axes[d].text(.5,.9,f"Det {d}",horizontalalignment='center',transform=axes[d].transAxes)
            if d in dets:
                line2, = axes[d].plot([0, ma[i]], [0, ma[i]], ls="--", c="k")
                line1, = axes[d].plot(counts[i], predicted[i], c="r")
                axes[d].fill_between(counts[i], predicted_lower[i], predicted_upper[i], color="r", alpha=0.5)
                i += 1
        plt.subplots_adjust(hspace=0, wspace=0)
        plt.subplots_adjust(hspace=0, top=0.96, bottom=0.1)
                
        fig.add_subplot(111, frameon=False)
        plt.tick_params(labelcolor='none', which='both', top=False, bottom=False, left=False, right=False)
        plt.xlabel("Cumulative Real Counts")
        plt.ylabel("Cumulative Predicted Counts", labelpad=27)
        
        fig.savefig(f"{self._folder}/{name}_qq.pdf")
        plt.close()
        
    def _extract_parameter_names_simple(self):
        self._parameter_names = []
        for full_name in self._source_model.free_parameters.keys():
            source = full_name[ : full_name.find(".")]
            source = source[1:] if source[0]=="_" else source
            source = source.replace("__", "+").replace("_", " ")
            parameter = full_name[-1 * full_name[::-1].find(".") : ]
            self._parameter_names.extend([f"{source} {parameter}"])
    
    def set_folder(self, folder):
        if not folder is None:
            if not os.path.exists(f"./{folder}"):
                os.mkdir(folder)
        self._folder = folder

    
    

In [59]:
def rebin_data_exp(
    bins,
    counts,
    energy_range
):

    if energy_range[0]:
        for i, e in enumerate(bins):
            if e > energy_range[0]:
                bins = bins[i:]
                counts = counts[:,i:] ############should these be the same?
                break
    if energy_range[1]:
        for i, e in enumerate(bins):
            if e > energy_range[1]:
                bins = bins[:i]
                counts = counts[:,:i-1]
                assert i > 1, "Max Energy is too low"
                break
        
    min_counts = 5
    
    max_num_bins = 120
    min_num_bins = 1
    
    finished = False
    
    while not finished:
        num_bins = round((max_num_bins + min_num_bins) / 2)
        
        if num_bins == max_num_bins or num_bins == min_num_bins:
            num_bins = min_num_bins
            finished = True
        
        temp_bins = np.geomspace(bins[0], bins[-1], num_bins+1)
        
        new_bins, new_counts = rebin_closest(bins, counts, temp_bins)
        
        if np.amin(new_counts) < min_counts:
            max_num_bins = num_bins
        else:
            min_num_bins = num_bins
            
    return new_bins, new_counts
    
# @njit
def rebin_closest(bins, counts, temp_bins):
    counts = np.copy(counts)
    closest1 = len(bins) - 1
    for i in range(len(temp_bins)-2, 0, -1):
        closest2 = np.argpartition(
            np.absolute(bins - temp_bins[i]),
            0
        )[0]
        if closest1 - closest2 >= 2:
            counts[:,closest2] += np.sum(counts[:, closest2+1 : closest1], axis=1)
            counts = np.delete(
                counts,
                [j for j in range(closest2+1, closest1)],
                axis=1
            )
            bins = np.delete(
                bins,
                [j for j in range(closest2+1, closest1)]
            )
        closest1 = closest2
    return bins, counts


# counts = np.linspace(1,40,40).reshape((2,20))
# bins = np.linspace(1,21,21)
# temp_bins = np.geomspace(1,21,11)
# b, c = rebin_closest(bins, counts, temp_bins)
# print(bins)
# print(counts)
# print(np.sum(counts, axis=1))
# print(temp_bins)
# print(b)
# print(c)
# print(np.sum(c, axis=1))

In [14]:
def save_clusters(pointings, folder):
    if not os.path.exists(f"./{folder}"):
        os.mkdir(folder)
    with open(f"./{folder}/pointings.pickle", "wb") as f:
        pickle.dump(pointings, f)
        
def load_clusters(folder):
    with open(f"./{folder}/pointings.pickle", "rb") as f:
        pointings = pickle.load(f)
    return pointings

In [15]:
def extract_date_range(path):
    with fits.open(f"{path}/pointing.fits") as file:
        t = Table.read(file[1])
        t1 = at.Time(f'{t["TSTART"][0]+2451544.5}', format='jd')
        t1.format = "isot"
        t2 = at.Time(f'{t["TSTOP"][-1]+2451544.5}', format='jd')
        t2.format = "isot"
    return t1.value, t2.value

In [16]:
class PointingClusters: #add min time diff
    def __init__(
        self,
        orbit_paths,
        min_angle_dif,
        max_angle_dif,
        max_time_dif,
        radius_around_crab,
        min_time_elapsed,
        cluster_size_range,
        random_angle_dif_range=None,
    ):
        self._orbit_paths = orbit_paths
        self._min_angle_dif = min_angle_dif
        self._max_angle_dif = max_angle_dif
        self._max_time_dif = max_time_dif
        self._radius_around_crab = radius_around_crab
        self._min_time_elapsed = min_time_elapsed
        self._cluster_size_range = cluster_size_range
        self._random_angle_dif_range = random_angle_dif_range
        
        pointings = []
        self._get_scw_ids()
        cq = ClusteredQuery(
            self._scw_ids,
            angle_weight=0.,
            time_weight=1./self._max_time_dif,
            max_distance=1.,
            min_ang_distance=self._min_angle_dif,
            max_ang_distance=self._max_angle_dif,
            cluster_size_range = self._cluster_size_range,
            failed_improvements_max=3,
            suboptimal_cluster_size=max(1,self._cluster_size_range[0]),
            close_suboptimal_cluster_size=max(1,self._cluster_size_range[0])
        ).get_clustered_scw_ids()
        
        for size in range(self._cluster_size_range[0], self._cluster_size_range[1] + 1):
            for cluster in cq[size]:
                pointings.append(tuple([(i, f"crab_data/{i[:4]}") for i in cluster]))
                
        self.pointings = tuple(pointings)             
    
    def _get_scw_ids(self, print_results=False):
        p = SkyCoord(83.6333, 22.0144, frame="icrs", unit="deg")
        searchquerry = SearchQuery(position=p, radius=f"{self._radius_around_crab} degree",)
        cat = IntegralQuery(searchquerry)

        scw_ids_all = None
        for path in self._orbit_paths:
            f = Filter(
                SCW_TYPE="POINTING",
                TIME=Range(*extract_date_range(path))
            )
            if scw_ids_all:
                scw_ids_all = np.append(
                    scw_ids_all,
                    cat.apply_filter(f, return_coordinates=True, remove_duplicates=True),
                    axis=0
                )
            else:
                scw_ids_all = cat.apply_filter(f, return_coordinates=True, remove_duplicates=True)
        
        scw_ids = []
        
        multiple_files = []
        no_files = []
        no_pyspi = []
        
        num_dets = 19
        eb = np.geomspace(18, 2000, 5)
        emod = np.geomspace(18, 2000, 5)
        for scw_id in scw_ids_all:
            good = True
            with fits.open(f"{path}/pointing.fits") as file:
                t = Table.read(file[1])
                index = np.argwhere(t["PTID_ISOC"]==scw_id[0][:8])
                
                if len(index) < 1:
                    no_files.append(scw_id)
                    good = False
                    continue
                    
                elif len(index) > 1:
                    multiple_files.append(scw_id)
                    good = False
                                
                pointing_info = t[index[-1][0]]
            
                t1 = at.Time(f'{pointing_info["TSTART"]+2451544.5}', format='jd').datetime
                time_start = datetime.strftime(t1,'%y%m%d %H%M%S')
                                
                with fits.open(f"{path}/dead_time.fits") as file2:
                    t2 = Table.read(file2[1])
                    
                    time_elapsed = np.zeros(num_dets)
                    
                    for i in range(num_dets):
                        for j in index:
                            time_elapsed[i] += t2["LIVETIME"][j[0]*85 + i]
                                
                dets = get_live_dets(time=time_start, event_types=["single"])
                                
                if not np.amin(time_elapsed[dets]) > self._min_time_elapsed:
                    good = False
            
            try: # investigate why this is necessary
                version = find_response_version(time_start)
                rsp = ResponseRMFGenerator.from_time(time_start, dets[0], eb, emod, rsp_bases[version])
            except:
                no_pyspi.append(scw_id)
                good = False
                
            if good:
                scw_ids.append(scw_id)
                
        if print_results:
            print("Multiple Files:")
            print(multiple_files)
            print("No Files:")
            print(no_files)
            print("No PySpi:")
            print(no_pyspi)
            print("Good:")
            print(scw_ids)
        
        self._scw_ids = np.array(scw_ids)
        
    

In [17]:
def define_sources(source_funcs):    
    model = Model()
    for source_func, params in source_funcs:
        source_func(model, *params)
    return model

In [18]:
def crab_pl_fixed_pos(model, piv):
    ra, dec = 83.6333, 22.0144
    
    pl = Powerlaw()
    pl.piv = piv
    pl.K.prior = Log_uniform_prior(lower_bound=1e-6, upper_bound=1e0)
    pl.index.prior = Uniform_prior(lower_bound=-4, upper_bound=0)
    component1 = SpectralComponent("pl", shape=pl)
    ps = PointSource("Crab", ra=ra, dec=dec, components=[component1])
    
    model.add_source(ps)
    return model

def crab_pl_free_pos(model, piv):
    ra, dec = 83.6333, 22.0144
    angle_range = 5.
    
    pl = Powerlaw()
    pl.piv = piv
    pl.K.prior = Log_uniform_prior(lower_bound=1e-6, upper_bound=1e0)
    pl.index.prior = Uniform_prior(lower_bound=-4, upper_bound=0)
    component1 = SpectralComponent("pl", shape=pl)
    ps = PointSource("Crab", ra=ra, dec=dec, components=[component1])
    ps.position.ra.free = True
    ps.position.ra.prior = Uniform_prior(
        lower_bound = ra - abs(angle_range/np.cos(dec)),
        upper_bound = ra + abs(angle_range/np.cos(dec))
    )
    ps.position.dec.free = True
    ps.position.dec.prior = Uniform_prior(
        lower_bound = dec - angle_range,
        upper_bound = dec + angle_range
    )
    
    model.add_source(ps)
    return model

def _1A_0535_262_pl(model, piv):
    ra, dec = 84.7270, 26.3160
    
    pl = Powerlaw()
    pl.piv = piv
    pl.K.prior = Log_uniform_prior(lower_bound=1e-10, upper_bound=1e0)
    pl.index.prior = Uniform_prior(lower_bound=-4, upper_bound=0)
    component1 = SpectralComponent("pl", shape=pl)
    ps = PointSource("_1A_0535__262", ra=ra, dec=dec, components=[component1])
    
    model.add_source(ps)
    return model

def _4U_0517_17_pl(model, piv):
    ra, dec = 77.6896, 16.4986
    
    pl = Powerlaw()
    pl.piv = piv
    pl.index.min_value = -20.
    pl.K.prior = Log_uniform_prior(lower_bound=1e-10, upper_bound=1e-4)
    pl.index.prior = Uniform_prior(lower_bound=-20, upper_bound=10)
    component1 = SpectralComponent("pl", shape=pl)
    ps = PointSource("_4U_0517__17", ra=ra, dec=dec, components=[component1])
    
    model.add_source(ps)
    return model

def _4U_0614_09_pl(model, piv):
    ra, dec = 94.2800, 9.13700
    
    pl = Powerlaw()
    pl.piv = piv
    pl.index.max_value = 20.
    pl.K.prior = Log_uniform_prior(lower_bound=1e-10, upper_bound=1e-4)
    pl.index.prior = Uniform_prior(lower_bound=-10, upper_bound=20)
    component1 = SpectralComponent("pl", shape=pl)
    ps = PointSource("_4U_0614__09", ra=ra, dec=dec, components=[component1])
    
    model.add_source(ps)
    return model

def geminga_pl(model, piv):
    ra, dec = 98.4750, 17.7670
    
    pl = Powerlaw()
    pl.piv = piv
    pl.index.max_value = 15.
    pl.K.prior = Log_uniform_prior(lower_bound=1e-15, upper_bound=1e-5)
    pl.index.prior = Uniform_prior(lower_bound=-10, upper_bound=15)
    component1 = SpectralComponent("pl", shape=pl)
    ps = PointSource("Geminga", ra=ra, dec=dec, components=[component1])
    
    model.add_source(ps)
    return model

In [45]:
def true_values(include_position=False):
    piv = 40
    
    crab_parameters = np.array([[9.3, 1, -2.08, 83.6333, 22.0144],
                                [7.52e-4, 100, -1.99, 83.6333, 22.0144],
                                [11.03, 1, -2.1, 83.6333, 22.0144]])
    
    crab_values = np.zeros((len(crab_parameters), len(crab_parameters[0]) - 1))
    crab_values[:,0] = crab_parameters[:,0] * (piv / crab_parameters[:,1])**crab_parameters[:,2]
    crab_values[:,1:] = crab_parameters[:,2:]
    
    names = ["Crab K", "Crab index", "Crab ra", "Crab dec"]
    if include_position:
        crab_values = (names, crab_values)
    else:
        crab_values = (names[:-2], crab_values[:,:-2])
    
    return crab_values

In [42]:
folder = "orbit_1019"

# pointings = PointingClusters(
#     ("crab_data/1019",),
#     min_angle_dif=1.5,
#     max_angle_dif=4.,
#     max_time_dif=0.2,
#     radius_around_crab=5.,
#     min_time_elapsed=600.,
#     cluster_size_range=(2,2),
# ).pointings
# save_clusters(pointings, folder)

pointings = load_clusters(folder)
source_model = define_sources((
    (crab_pl_fixed_pos, (40,)),
    (_1A_0535_262_pl, (40,)),
))

multinest_fit = MultinestClusterFit(
    pointings,
    source_model,
    (None, 80),
    np.geomspace(18, 150, 50),
    rebin_data_exp,
    true_values=true_values(),
    folder=folder,
)

multinest_fit.parameter_fit_distribution()
multinest_fit.text_summaries()
multinest_fit.ppc()

 *****************************************************
 MultiNest v3.10
 Copyright Farhan Feroz & Mike Hobson
 Release Jul 2015

 no. of live points =  800
 dimensionality =    4
 *****************************************************
 Starting MultiNest
 generating live points
 live points generated, starting sampling
Acceptance Rate:                        0.998825
Replacements:                                850
Total Samples:                               851
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): ************** +/-  0.999412
Acceptance Rate:                        0.998890
Replacements:                                900
Total Samples:                               901
Nested Sampling ln(Z):            5172337.139658
Importance Nested Sampling ln(Z): ************** +/-  0.999445
Acceptance Rate:                        0.989583
Replacements:                                950
Total Samples:                               960
Nested Sampling

In [62]:
folder = "triple_1019"

# save_clusters(
#     ((('101900550010', "crab_data/1019"), ('101900590010', "crab_data/1019"), ('101900600010', "crab_data/1019")),),
#     folder
# )

pointings = load_clusters(folder)
source_model = define_sources((
    (crab_pl_fixed_pos, (40,)),
    (_1A_0535_262_pl, (40,)),
))

multinest_fit = MultinestClusterFit(
    pointings,
    source_model,
    (None, 80),
    np.geomspace(18, 150, 50),
    rebin_data_exp,
    true_values=true_values(),
    folder=folder,
)

multinest_fit.parameter_fit_distribution()
multinest_fit.text_summaries()
multinest_fit.ppc()

 *****************************************************
 MultiNest v3.10
 Copyright Farhan Feroz & Mike Hobson
 Release Jul 2015

 no. of live points =  800
 dimensionality =    4
 *****************************************************
 Starting MultiNest
 generating live points
 live points generated, starting sampling
Acceptance Rate:                        0.994152
Replacements:                                850
Total Samples:                               855
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): 6375212.611014 +/-  0.999415
Acceptance Rate:                        0.991189
Replacements:                                900
Total Samples:                               908
Nested Sampling ln(Z):            -618895.344040
Importance Nested Sampling ln(Z): 6375212.550871 +/-  0.999449
Acceptance Rate:                        0.980392
Replacements:                                950
Total Samples:                               969
Nested Sampling

ERROR: Interrupt received: Terminating
Exception ignored on calling ctypes callback function: <function run.<locals>.loglike at 0x7f8dacdd8820>
Traceback (most recent call last):
  File "/home/moej56153/.pyenv/versions/MT/lib/python3.9/site-packages/pymultinest/run.py", line 221, in loglike
    return LogLikelihood(cube, ndim, nparams)
  File "/tmp/ipykernel_342/2937187804.py", line 48, in logLba_mult
  File "/home/moej56153/.pyenv/versions/MT/lib/python3.9/site-packages/pymultinest/run.py", line 70, in interrupt_handler
    sys.exit(1)
SystemExit: 1


Acceptance Rate:                        0.699934
Replacements:                               4250
Total Samples:                              6072
Nested Sampling ln(Z):            6373819.642252
Importance Nested Sampling ln(Z): 6375289.539037 +/-  0.999918
Acceptance Rate:                        0.697712
Replacements:                               4300
Total Samples:                              6163
Nested Sampling ln(Z):            6373840.524432
Importance Nested Sampling ln(Z): 6375289.460006 +/-  0.999919
Acceptance Rate:                        0.697227
Replacements:                               4350
Total Samples:                              6239
Nested Sampling ln(Z):            6373860.068280
Importance Nested Sampling ln(Z): 6375289.398077 +/-  0.999920
Acceptance Rate:                        0.695872
Replacements:                               4400
Total Samples:                              6323
Nested Sampling ln(Z):            6373874.234335
Importance Nested Sampling 

In [45]:
folder = "orbit_1019_extended"

# pointings = PointingClusters(
#     ("crab_data/1019",),
#     min_angle_dif=1.5,
#     max_angle_dif=4.,
#     max_time_dif=0.2,
#     radius_around_crab=5.,
#     min_time_elapsed=600.,
#     cluster_size_range=(2,2),
# ).pointings
# save_clusters(pointings, folder)

pointings = load_clusters(folder)
source_model = define_sources((
    (crab_pl_free_pos, (40,)),
    (_1A_0535_262_pl, (40,)),
    (geminga_pl, (40,)),
    (_4U_0517_17_pl, (40,)),
    (_4U_0614_09_pl, (40,)),
))

multinest_fit = MultinestClusterFit(
    pointings,
    source_model,
    (None, 80),
    np.geomspace(18, 150, 50),
    rebin_data_exp,
    true_values=true_values(),
    folder=folder,
)

multinest_fit.parameter_fit_distribution()
multinest_fit.text_summaries()
multinest_fit.ppc()


 *****************************************************
 MultiNest v3.10
 Copyright Farhan Feroz & Mike Hobson
 Release Jul 2015

 no. of live points =  800
 dimensionality =   12
 *****************************************************
 Starting MultiNest
 generating live points
 live points generated, starting sampling
Acceptance Rate:                        0.997653
Replacements:                                850
Total Samples:                               852
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): ************** +/-  0.999413
Acceptance Rate:                        0.994475
Replacements:                                900
Total Samples:                               905
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): ************** +/-  0.999447
Acceptance Rate:                        0.988554
Replacements:                                950
Total Samples:                               961
Nested Sampling

ERROR: Interrupt received: Terminating
Exception ignored on calling ctypes callback function: <function run.<locals>.loglike at 0x7f37743ef310>
Traceback (most recent call last):
  File "/home/moej56153/.pyenv/versions/MT/lib/python3.9/site-packages/pymultinest/run.py", line 221, in loglike
    return LogLikelihood(cube, ndim, nparams)
  File "/tmp/ipykernel_326/1986421315.py", line 47, in logLba_mult
  File "/tmp/ipykernel_326/1986421315.py", line 219, in _update_resp_mats
  File "/tmp/ipykernel_326/3405850027.py", line 11, in generate_resp_mat
  File "/home/moej56153/.pyenv/versions/MT/lib/python3.9/site-packages/pyspi/utils/response/spi_drm.py", line 18, in __init__
    self._drm_generator.set_location(ra, dec)
  File "/home/moej56153/.pyenv/versions/MT/lib/python3.9/site-packages/pyspi/utils/response/spi_response.py", line 282, in set_location
    self._recalculate_response()
  File "/home/moej56153/.pyenv/versions/MT/lib/python3.9/site-packages/pyspi/utils/response/spi_response.py

Acceptance Rate:                        0.916730
Replacements:                               1200
Total Samples:                              1309
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): ************** +/-  0.999618
Acceptance Rate:                        0.899928
Replacements:                               1250
Total Samples:                              1389
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): ************** +/-  0.999640
Acceptance Rate:                        0.890411
Replacements:                               1300
Total Samples:                              1460
Nested Sampling ln(Z):            **************
Importance Nested Sampling ln(Z): ************** +/-  0.999657
Acceptance Rate:                        0.886991
Replacements:                               1350
Total Samples:                              1522
Nested Sampling ln(Z):            **************
Importance Nested Sampling 

In [60]:
folder = "movable_crab_1644_pair"

# save_clusters(
#     ((('164400150010', "crab_data/1644"), ('164400220010', "crab_data/1644")),),
#     folder
# )

pointings = load_clusters(folder)
source_model = define_sources((
    (crab_pl_free_pos, (40,)),
))

multinest_fit = MultinestClusterFit(
    pointings,
    source_model,
    (None, 80),
    np.geomspace(18, 150, 50),
    rebin_data_exp,
    true_values=true_values(),
    folder=folder,
)

multinest_fit.parameter_fit_distribution()
multinest_fit.text_summaries()
multinest_fit.ppc()

 *****************************************************
 MultiNest v3.10
 Copyright Farhan Feroz & Mike Hobson
 Release Jul 2015

 no. of live points =  800
 dimensionality =    4
 *****************************************************
 Starting MultiNest
 generating live points
 live points generated, starting sampling
Acceptance Rate:                        0.995316
Replacements:                                850
Total Samples:                               854
Nested Sampling ln(Z):            -441291.759004
Importance Nested Sampling ln(Z): 3534231.838453 +/-  0.999414
Acceptance Rate:                        0.990099
Replacements:                                900
Total Samples:                               909
Nested Sampling ln(Z):            2320782.623115
Importance Nested Sampling ln(Z): 3534231.776039 +/-  0.999450
Acceptance Rate:                        0.980392
Replacements:                                950
Total Samples:                               969
Nested Sampling