# Investigate how sensitivity changes without SN2009hd #

In [None]:
# IPython autoreload extension: in mode 2 modules are reloaded before code is
# executed, so edits to the imported analysis code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import math, pickle, os, logging, time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from numpy.lib.recfunctions import drop_fields

In [None]:
from flarestack.core.results import ResultsHandler, OverfluctuationError
from flarestack.core.experimental_results import ExperimentalResultHandler
from flarestack.data.icecube import ps_v002_p01
from flarestack.shared import plot_output_dir, flux_to_k
from flarestack.icecube_utils.reference_sensitivity import reference_sensitivity
from flarestack.analyses.ccsn.stasik_2017.ccsn_limits import limits, get_figure_limits, p_vals
from flarestack.analyses.ccsn.necker_2019.ccsn_helpers import sn_cats, updated_sn_catalogue_name, sn_time_pdfs, raw_output_dir, pdf_names, limit_sens
from flarestack.analyses.ccsn.necker_2019.build_catalogues_from_raw import load_catalogue
from flarestack.analyses.ccsn import get_sn_color
from flarestack.cluster import Submitter
from flarestack.utils.custom_dataset import custom_dataset

In [None]:
# Verbose logging for this notebook, while keeping matplotlib's debug output quiet.
root_logger = logging.getLogger()
root_logger.setLevel("DEBUG")
logging.debug('logging level is DEBUG')

mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel('INFO')

# notebook-level logger used by the cells below
logger = logging.getLogger('main')

In [None]:
# ---- analysis configuration ----

# energy PDF used in the likelihood: a power law
llh_energy = dict(energy_pdf_name="power_law")

# run on the cluster; number of trials handed to each analysis
cluster = True
ntrials = 500

# spectral indices to loop over (and inject)
gammas = [2.0, 2.5]

# minimisation handler to use
mh_name = 'fit_weights'

# base name under which all analyses of this notebook are stored
raw = raw_output_dir + "/compare_sensitivity_SN2009hd/" + mh_name

In [None]:
# Set up an empty dictionary to store the submitters in, keyed as
# full_res[pdf_type][catalogue][time_key][gamma]
full_res = dict()

# NOTE(review): `flagged` was used below but never defined in this notebook,
# which raised a NameError. False is assumed to select the regular
# (non-flagged) IIP catalogue -- confirm this is the intended catalogue.
flagged = False

for pdf_type in ['box', 'decay']:
    pdf_res = dict()
    pdf_type_name = f'{raw}/{pdf_type}'

    # the time PDFs only depend on the PDF type, so fetch them once here
    time_pdfs = sn_time_pdfs('IIP', pdf_type=pdf_type)

    # loop over catalogues with and without SN2009hd
    for cat in ['with SN2009hd', 'without SN2009hd']:

        name = f'{pdf_type_name}/{cat.replace(" ", "_")}'

        # set up empty results dictionary for this catalogue
        cat_res = dict()

        # The catalogue does not depend on the time PDF, so load it (and, if
        # needed, extend and save it) once per catalogue variant.
        cat_path = updated_sn_catalogue_name('IIP', flagged=flagged)
        catalogue = np.load(cat_path)
        logging.debug('catalogue dtype: ' + str(catalogue.dtype))

        # select the closest source (before SN2009hd is appended); it serves
        # for estimating a good injection scale later
        closest_src = np.sort(catalogue, order="distance_mpc")[0]

        if cat == 'with SN2009hd':
            # take SN2009hd from the catalogue of missed objects; the 'weight'
            # field is dropped, presumably so the dtype matches `catalogue`
            missed_IIPs = drop_fields(
                load_catalogue('IIP', 'missed_objects', include_flagged=True),
                'weight'
            )
            SN2009hd = missed_IIPs[missed_IIPs['source_name'] == 'SN2009hd']
            catalogue = np.append(catalogue, SN2009hd)
            # str.strip(".npy") would strip any of '.', 'n', 'p', 'y' from both
            # ends; splitext removes exactly the file extension
            cat_path = os.path.splitext(cat_path)[0] + '_with_SN2009hd.npy'
            np.save(cat_path, catalogue)

        # sanity check that the catalogue content matches its label
        has_SN2009hd = 'SN2009hd' in catalogue['source_name']
        if (cat == 'with SN2009hd') and not has_SN2009hd:
            raise Exception(f'Catalogue is {cat} but SN2009hd is not in the catalogue!')
        if (cat == 'without SN2009hd') and has_SN2009hd:
            raise Exception(f'Catalogue is {cat} but SN2009hd is in catalogue!')

        logging.debug('catalogue path: ' + str(cat_path))

        # Loop over time PDFs
        for llh_time in time_pdfs:

            # set up an empty results dictionary for this time pdf
            time_res = dict()

            logging.debug(f'time_pdf is {llh_time}')

            # results are keyed by the box length (pre + post window) or the
            # decay time of the time PDF
            if pdf_type == 'box':
                time_key = str(llh_time["post_window"] + llh_time["pre_window"])
            else:
                time_key = str(llh_time['decay_time'])

            time_name = f'{name}/{time_key}'

            # set up the likelihood dictionary
            llh_dict = {
                "llh_name": "standard",
                "llh_energy_pdf": llh_energy,
                "llh_sig_time_pdf": llh_time,
                "llh_bkg_time_pdf": {
                    "time_pdf_name": "steady"
                }
            }

            # the injection time PDF is the same as the likelihood time PDF
            injection_time = llh_time

            # the scale estimate below is divided by this length
            length = float(time_key)

            # Loop over spectral indices
            for gamma in gammas:

                full_name = f'{time_name}/{gamma:.2f}'

                # try to estimate a good scale based on the sensitivity from the
                # 7-yr PS sensitivity at the declination of the closest source
                scale = 0.5 * (flux_to_k(reference_sensitivity(
                    np.sin(closest_src["dec_rad"]), gamma=gamma
                ) * 40 * math.sqrt(float(len(catalogue)))) * 200.) / length

                # in some cases the sensitivity is outside the tested range;
                # these empirical factors adjust the scale to cover it
                if pdf_type == 'decay':
                    scale /= 20
                    if gamma == 2.5:
                        scale *= 0.15
                    if gamma == 2.0:
                        scale *= 0.5
                else:
                    scale *= 0.2
                    if gamma == 2.5:
                        scale *= 0.6

                # set up an injection dictionary with the desired spectral index
                injection_energy = dict(llh_energy)
                injection_energy["gamma"] = gamma

                inj_dict = {
                    "injection_energy_pdf": injection_energy,
                    "injection_sig_time_pdf": injection_time,
                    "poisson_smear_bool": True,
                }

                # set up the final minimizer dictionary
                mh_dict = {
                    "name": full_name,
                    "mh_name": mh_name,
                    "dataset": custom_dataset(ps_v002_p01, catalogue,
                                              llh_dict["llh_sig_time_pdf"]),
                    "catalogue": cat_path,
                    "inj_dict": inj_dict,
                    "llh_dict": llh_dict,
                    "scale": scale,
                    "n_trials": ntrials,
                    "n_steps": 10
                }

                # create the submitter; the analysis itself is started later
                submitter = Submitter.get_submitter(
                    mh_dict, cluster, 5,
                    do_sensitivity_scale_estimation=False,
                    remove_old_results=True,
                    cluster_cpu=1,
                    h_cpu='03:59:59',
                    trials_per_task=1
                )
                logging.debug(f'submitter is {submitter}')
                time_res[gamma] = submitter

            cat_res[time_key] = time_res

        pdf_res[cat] = cat_res

    full_res[pdf_type] = pdf_res

In [None]:
# Start the analyses. Only the box time PDFs at gamma = 2.5 are launched from
# this cell; all other submitters in full_res are left untouched.
for pdf_type, pdf_res in full_res.items():
    for cat_res in pdf_res.values():
        for time_res in cat_res.values():
            for gamma, s in time_res.items():
                if pdf_type != 'box' or gamma != 2.5:
                    continue
                logger.debug(s)
                s.analyse()

In [None]:
# Block until every submitter stored in
# full_res[pdf_type][catalogue][time_key][gamma] has finished its job.
for pdf_res in full_res.values():
    for cat_res in pdf_res.values():
        for time_res in cat_res.values():
            for s in time_res.values():
                logging.debug(f'waiting on {s}')
                s.wait_for_job()

In [None]:
# Collect the sensitivities, keyed as
# stacked_sens_flux[pdf_type][catalogue][time_key][gamma] -> dict with
# 'sens', 'sens_e' (flux) and 'sens_n', 'sens_n_e' (converted via flux_to_ns).
# An analysis that raises OverfluctuationError is logged and skipped, so a
# gamma may be missing for some time keys.
stacked_sens_flux = {}
logging.getLogger().setLevel('WARNING')

for pdf_type, pdf_res in full_res.items():
    pdf_sens = dict()
    for cat, cat_res in pdf_res.items():
        cat_sens = dict()
        # `time_key` instead of `time`, which would shadow the imported module
        for time_key, time_res in cat_res.items():
            time_sens = dict()
            for gamma, s in time_res.items():

                try:
                    rh = ResultsHandler(s.mh_dict)
                    time_sens[gamma] = {
                        'sens': rh.sensitivity,
                        'sens_e': rh.sensitivity_err,
                        'sens_n': rh.sensitivity * rh.flux_to_ns,
                        'sens_n_e': rh.sensitivity_err * rh.flux_to_ns
                    }
                except OverfluctuationError as e:
                    logging.warning(f'{e} for {pdf_type} {cat} {time_key} {gamma}')

            cat_sens[time_key] = time_sens
        pdf_sens[cat] = cat_sens
    stacked_sens_flux[pdf_type] = pdf_sens

### plot results ###

In [None]:
# apply seaborn's default plotting theme to all subsequent matplotlib figures
sns.set()

In [None]:
# Plot styling: a colour and a small multiplicative x-offset per catalogue
# (so points of the two catalogues do not overlap on the log x-axis), and a
# line style per spectral index.
catalogues = ('with SN2009hd', 'without SN2009hd')
cat_colors = dict(zip(catalogues, ('k', 'b')))
gamma_ls = {2.0: '-', 2.5: '--'}
cat_offset = dict(zip(catalogues, (0.98, 1.02)))

In [None]:
# One figure per time-PDF type showing all catalogues and spectral indices:
# colour encodes the catalogue, the error-bar line style encodes gamma, and a
# small x-offset per catalogue keeps overlapping points distinguishable.
for pdf_type, pdf_sens in stacked_sens_flux.items():

    fig, ax = plt.subplots()
    # legend handles, keyed as handles[catalogue][gamma]
    handles = dict()

    for cat, cat_sens in pdf_sens.items():
        handles[cat] = dict()
        offset = cat_offset[cat]
        for time_key, time_sens in cat_sens.items():
            for gamma, sens in time_sens.items():

                h = ax.errorbar(float(time_key) * offset, sens['sens'],
                                yerr=np.atleast_2d(sens['sens_e']).T,
                                color=cat_colors[cat], capsize=5, marker='o',
                                label=rf'{cat} $\gamma$={gamma:.2f}')
                # h[2] holds the error-bar line collections; style them by gamma
                h[2][0].set_linestyle(gamma_ls[gamma])
                # keep one representative handle per (catalogue, gamma)
                handles[cat][gamma] = h

    ax.set_xlabel(f'{pdf_type} length [d]')
    ax.set_ylabel(r'flux [GeV$^{-1}$ s$^{-1}$ cm$^{-2}$]')
    ax.set_title(f'{pdf_type} time profile')

    ax.set_yscale('log')
    ax.set_xscale('log')

    # one legend entry per (catalogue, gamma) instead of one per data point
    legend_handles = [
        h for cat_handles in handles.values() for h in cat_handles.values()
    ]
    ax.legend(handles=legend_handles)

    plt.show()
    plt.close()

In [None]:
# One figure per time-PDF type and spectral index, comparing the sensitivity
# flux of the catalogues with and without SN2009hd.
for pdf_type, pdf_sens in stacked_sens_flux.items():

    for gamma in gammas:

        fig, ax = plt.subplots()

        for cat, cat_sens in pdf_sens.items():
            # label only the first plotted point of each catalogue so the
            # legend gets a single entry per catalogue
            labelled = False
            offset = cat_offset[cat]

            for time_key, time_sens in cat_sens.items():

                # gamma is missing when no sensitivity could be derived
                # (OverfluctuationError while collecting the results)
                if gamma not in time_sens:
                    logger.debug(f'{pdf_type} {gamma} {cat} {time_key}: no sensitivity')
                    continue

                # look up this point's sensitivity before logging it
                sens = time_sens[gamma]
                logger.debug(f'{pdf_type} {gamma} {cat} {time_key}: {sens["sens"]}')

                label = '' if labelled else rf'{cat} $\gamma$={gamma:.2f}'
                ax.errorbar(float(time_key) * offset, sens['sens'],
                            yerr=np.atleast_2d(sens['sens_e']).T,
                            color=cat_colors[cat], capsize=5, marker='o', label=label)
                labelled = True

        unit = '[d]' if pdf_type == 'box' else '[y]'
        ax.set_xlabel(f'{pdf_type} length {unit}')
        ax.set_ylabel(r'flux [GeV$^{-1}$ s$^{-1}$ cm$^{-2}$]')
        # '\n' is a real line break, '\\gamma' the literal LaTeX command
        ax.set_title(f'{pdf_type} time profile\n$\\gamma$={gamma:.2f}')

        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.legend()

        # ticks at the tested time keys (taken from the last catalogue plotted)
        tick_keys = list(cat_sens.keys())
        ax.set_xticks([float(t) for t in tick_keys])
        # NOTE(review): decay keys are converted to years with 364.25 days per
        # year (not 365.25) -- presumably matching how the decay times were
        # defined; confirm
        xticklabels = tick_keys if pdf_type == 'box' else [float(t) / 364.25 for t in tick_keys]
        ax.set_xticklabels(xticklabels)

        plt.show()
        plt.close()

In [None]:
# One figure per time-PDF type and spectral index showing each catalogue's
# sensitivity divided by the reference catalogue's sensitivity at the same
# time key and gamma. The figures are also saved as PDFs.
ref_to = 'with SN2009hd'
logging.getLogger().setLevel('DEBUG')

# make sure the output directory exists so that savefig does not fail
plot_dir = plot_output_dir(raw)
os.makedirs(plot_dir, exist_ok=True)

for pdf_type, pdf_sens in stacked_sens_flux.items():

    ref_sens = pdf_sens[ref_to]

    for gamma in gammas:

        fig, ax = plt.subplots()

        for cat, cat_sens in pdf_sens.items():
            # label only the first plotted point of each catalogue so the
            # legend gets a single entry per catalogue
            labelled = False
            offset = cat_offset[cat]

            for time_key, time_sens in cat_sens.items():

                # skip points where this catalogue or the reference has no
                # sensitivity (OverfluctuationError while collecting results)
                reference_entry = ref_sens.get(time_key, {}).get(gamma)
                if (gamma not in time_sens) or (reference_entry is None):
                    logger.debug(f'{pdf_type} {gamma} {cat} {time_key}: no sensitivity, skipping')
                    continue

                sens = time_sens[gamma]
                reference = reference_entry['sens']
                ratio = sens['sens'] / reference
                logger.debug(f'{pdf_type} {gamma} {cat} {time_key}: {ratio}')

                label = '' if labelled else rf'{cat} $\gamma$={gamma:.2f}'
                # only this catalogue's uncertainty is scaled; the reference's
                # own uncertainty is not propagated
                ax.errorbar(float(time_key) * offset, ratio,
                            yerr=np.atleast_2d(sens['sens_e']).T / reference,
                            color=cat_colors[cat], capsize=5, marker='o', label=label)
                labelled = True

        ax.axhline(1, ls='--', alpha=0.5, color='k')
        unit = '[d]' if pdf_type == 'box' else '[y]'
        ax.set_xlabel(f'{pdf_type} length {unit}')
        ax.set_ylabel(r'flux / flux$_{\mathrm{with \; SN2009hd}}$')
        # '\n' is a real line break, '\\gamma' the literal LaTeX command
        ax.set_title(f'{pdf_type} time profile\n$\\gamma$={gamma:.2f}')

        ax.set_yscale('log')
        ax.set_xscale('log')
        ax.set_yticks([0.8, 0.9, 1, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7])

        # ticks at the tested time keys (taken from the last catalogue plotted)
        tick_keys = list(cat_sens.keys())
        logger.debug([float(t) for t in tick_keys])
        ax.set_xticks([float(t) for t in tick_keys])
        # NOTE(review): decay keys are converted to years with 364.25 days per
        # year (not 365.25) -- presumably matching how the decay times were
        # defined; confirm
        xticklabels = tick_keys if pdf_type == 'box' else [float(t) / 364.25 for t in tick_keys]
        ax.set_xticklabels(xticklabels)
        ax.legend()

        filename = os.path.join(plot_dir, f'{pdf_type}_gamma{gamma:.2f}.pdf')
        logger.debug(f'saving under {filename}')
        fig.tight_layout()
        fig.savefig(filename)

        plt.show()
        plt.close()