In [1]:
import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")

In [2]:
from contextlib import closing
import os
from pathlib import Path
import sqlite3

from astropy.coordinates import EarthLocation, ICRS, SkyCoord
from astropy.time import Time
from astropy.table import QTable, join
import astropy_healpix as ah
from astropy import units as u
from matplotlib import pyplot as plt
from matplotlib import patheffects
from matplotlib.colors import LinearSegmentedColormap
from m4opt.missions import uvex as mission
from m4opt.models import observing
import ligo.skymap
from ligo.skymap.io import read_sky_map
import numpy as np
import seaborn as sns
from scipy import stats
import synphot
from tqdm.auto import tqdm

from rate_stats import format_with_errorbars, poisson_lognormal_rate_quantiles
from detection_probability import get_detection_probability_unknown_position, get_detection_probability_known_position

# Set ligo.skymap's OpenMP thread count to 1.
# NOTE(review): presumably to avoid oversubscribing cores when this notebook
# runs alongside other jobs on a shared node — confirm.
ligo.skymap.omp.num_threads = 1

In [3]:
# Root directory containing one subdirectory per observing run (each with a
# `farah` subdirectory read by the cells below). Set the RUNS_BASE_PATH
# environment variable to run this notebook against a copy of the data stored
# elsewhere; without it, the original cluster location is used.
base_path = Path(os.environ.get('RUNS_BASE_PATH', '/home/lsinger/lustre/runs_SNR-10'))
# Observing scenarios to analyze.
runs = ['O5HLVK', 'O6HLVK']

In [None]:
# Read summary data for all events: for each run, join the sky map summary
# table (allsky.dat) with the injection table (injections.dat).
event_tables_by_run = {}
for run in tqdm(runs):
    farah_dir = base_path / run / 'farah'
    allsky_table = QTable.read(farah_dir / 'allsky.dat', format='ascii')
    injection_table = QTable.read(farah_dir / 'injections.dat', format='ascii')
    event_table = join(allsky_table, injection_table)
    # Row i must be coinc_event_id i, so that it lines up with the per-event
    # files {i}.ecsv / {i}.fits that are read by index below.
    assert (event_table['coinc_event_id'] == np.arange(len(event_table))).all()
    event_tables_by_run[run] = event_table

In [None]:
# Read all observing plans: one ECSV file per event, named by row index.
plans_by_run = {}
for run, event_table in event_tables_by_run.items():
    plan_dir = base_path / run / 'farah' / 'allsky'
    plans = [QTable.read(plan_dir / f'{i}.ecsv') for i in tqdm(range(len(event_table)))]
    plans_by_run[run] = plans

    # Save objective values for all plans
    event_table['objective_value'] = [plan.meta['objective_value'] for plan in plans]

# Get planner arguments (doesn't matter which event). Pick the reference plan
# explicitly (first plan of the last run read) rather than relying on the
# `plans` loop variable leaking out of the loop above.
reference_plan = list(plans_by_run.values())[-1][0]
plan_args = dict(reference_plan.meta['args'])
# The sky map argument is per-event; drop it if present.
plan_args.pop('skymap', None)

# HEALPix grid at the planner's resolution (nested ordering, ICRS frame).
hpx = ah.HEALPix(nside=plan_args['nside'], frame=ICRS(), order='nested')

In [None]:
# For every (plan, event) pair, evaluate get_detection_probability_known_position
# and store the results as a new column of the run's event table.
for event_table, plans in zip(event_tables_by_run.values(), plans_by_run.values()):
    known_position_probabilities = []
    for plan, event_row in zip(tqdm(plans), event_table):
        known_position_probabilities.append(
            get_detection_probability_known_position(plan, event_row, plan_args)
        )
    event_table['detection_probability_known_position'] = known_position_probabilities

In [None]:
# For every event, evaluate get_detection_probability_unknown_position using the
# event's plan and its sky map (read as a multi-order map, moc=True), and store
# the results as a new column of the run's event table.
for run, event_table in event_tables_by_run.items():
    plans = plans_by_run[run]
    skymap_dir = base_path / run / 'farah' / 'allsky'
    event_table['detection_probability_unknown_position'] = [
        get_detection_probability_unknown_position(
            plan, read_sky_map(skymap_dir / f'{event_id}.fits', moc=True), plan_args
        )
        for event_id, plan in enumerate(tqdm(plans))
    ]

In [8]:
# Selection boundaries in the (luminosity distance, credible area) plane, drawn
# on the area-distance figure.
#
# Limiting magnitude: the detector's limiting AB magnitude for a flat-spectrum
# (0 ABmag ConstFlux1D) source at the planner's SNR threshold and bandpass, for
# the longest possible single exposure: the smaller of the available window
# (deadline - delay) and exptime_max. Observing conditions are evaluated toward
# every HEALPix pixel, for an observer at the geocenter and a fixed epoch.
# NOTE(review): the observer location and epoch look like placeholders — confirm
# they have negligible effect on the result.
with observing(
    observer_location=EarthLocation(0 * u.m, 0 * u.m, 0 * u.m),
    target_coord=hpx.healpix_to_skycoord(np.arange(hpx.npix)),
    obstime=Time('2025-01-01'),
):
    # .max() keeps the largest (faintest) limiting magnitude over the sky.
    limmag = mission.detector.get_limmag(
        plan_args['snr'],
        min(plan_args['deadline'] - plan_args['delay'], plan_args['exptime_max']),
        synphot.SourceSpectrum(synphot.ConstFlux1D, amplitude=0 * u.ABmag), plan_args['bandpass']
    ).max()

# Credible level (percent) of the sky map area column used in the figure.
skymap_area_cl = 90
# NOTE(review): presumably the area of a single 3.5 deg x 3.5 deg field — confirm.
min_area = (3.5 * u.deg)**2

# For a 2-D Gaussian, the area enclosing probability p scales as the
# chi-squared (2 dof) quantile at p. area_factor is therefore the ratio of the
# skymap_area_cl% area to the area at the planner's `cutoff` probability
# (assumes cutoff is in (0, 1)).
chisq_ppf = stats.chi2(df=2).ppf
area_factor = (chisq_ppf(skymap_area_cl / 100) / chisq_ppf(plan_args['cutoff']))
# Largest coverable area: number of fields that fit in the window when each of
# the `visits` visits uses exptime_min, times the field area, rescaled by area_factor.
max_area = (area_factor * min_area * (plan_args['deadline'] - plan_args['delay']) / (plan_args['visits'] * plan_args['exptime_min']))
# Distance modulus m - M = 5 log10(d / Mpc) + 25 solved for d, with m = limmag and
# M taken norm.ppf(1 - cutoff) standard deviations brighter than absmag_mean.
max_distance = 10**(0.2 * (limmag.to_value(u.mag) - (plan_args['absmag_mean'] - stats.norm.ppf(1 - plan_args['cutoff']) * plan_args['absmag_stdev']) - 25)) * u.Mpc
# Distance at which the area ∝ distance^-4 segment through (max_distance, min_area)
# reaches max_area.
crossover_distance = max_distance * (min_area / max_area)**.25

In [None]:
# Area-distance figure: one panel per run plotting each event's credible area
# against luminosity distance, with the selection boundary computed above.
cmap = plt.get_cmap('cool')
# Keep only the upper two thirds of the 'cool' colormap.
cmap = LinearSegmentedColormap.from_list('truncated_cool', cmap(np.linspace(1/3, 1)))
with plt.style.context('seaborn-v0_8-paper'):
    fig = plt.figure(figsize=(7 + 1/3, 3))
    # Two equal-width panels plus a narrow colorbar column.
    gs = plt.GridSpec(1, 3, figure=fig, width_ratios=(1, 1, 0.05), wspace=0.05, bottom=0.2, left=0.1, right=0.925)
    ax = fig.add_subplot(gs[0], aspect=0.25)
    ax.set_xlim(5e1, 5e3)
    # Upper limit: the whole sky (u.spat) in deg^2.
    ax.set_ylim(5e-2, u.spat.to(u.deg**2))
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_ylabel(f'{skymap_area_cl}% credible area (deg$^2$)')
    axs = [ax]
    ax = fig.add_subplot(gs[1], sharex=ax, sharey=ax, aspect=0.25)
    plt.setp(ax.yaxis.get_ticklabels(), visible=False)
    ax.yaxis.get_label().set_visible(False)
    axs.append(ax)

    for ax, (run, event_table) in zip(axs, event_tables_by_run.items()):
        ax.set_xlabel('Luminosity distance (Mpc)')
        # Limits were set explicitly (and shared), so they are fixed from here on.
        xmin, xmax = ax.get_xlim()
        ymin, ymax = ax.get_ylim()
        # Boundary: constant max_area out to crossover_distance, then
        # area ∝ distance^-4 down to min_area at max_distance, then vertical.
        boundary_x = u.Quantity([xmin * u.Mpc, crossover_distance, max_distance, max_distance]).to_value(u.Mpc)
        boundary_y = u.Quantity([max_area, max_area, min_area, ymin * u.deg**2]).to_value(u.deg**2)
        # Shade the region above/right of the boundary, out to the right edge.
        ax.fill_between(
            np.append(boundary_x, xmax),
            np.append(boundary_y, ymin),
            np.tile(ymax, 5),
            color='gainsboro'
        )
        # All events as small gray dots, then overplotted with marker size
        # proportional to the known-position detection probability and color
        # given by the plan's objective value.
        ax.scatter('distance', f'area({skymap_area_cl})', s=1, facecolor='silver', edgecolor='none', data=event_table)
        scatter = ax.scatter('distance', f'area({skymap_area_cl})', s=event_table['detection_probability_known_position'] * 30, c=event_table['objective_value'], cmap=cmap, vmin=0, vmax=1, data=event_table)
        lines, = ax.plot(boundary_x, boundary_y, color='tab:blue')
        label_kwargs = dict(
            color=lines.get_color(),
            ha='center',
            va='bottom',
            rotation_mode='anchor',
            linespacing=0.1,
            path_effects=[patheffects.withStroke(linewidth=2, foreground='white')],
            fontsize=0.8 * plt.rcParams['axes.labelsize']
        )
        # Geometric means place the labels at segment midpoints on the log axes.
        ax.text(np.sqrt(xmin * crossover_distance.to_value(u.Mpc)), max_area.to_value(u.deg**2), 'Max area\n', **label_kwargs)
        ax.text(max_distance.to_value(u.Mpc), np.sqrt(ymin * min_area.to_value(u.deg**2)), 'Max distance\n', rotation=-90, **label_kwargs)
        ax.text(
            np.sqrt(crossover_distance.to_value(u.Mpc) * max_distance.to_value(u.Mpc)),
            np.sqrt(min_area.to_value(u.deg**2) * max_area.to_value(u.deg**2)),
            # Raw string for the TeX part: '\p' is an invalid escape in a normal literal.
            r'Area $\propto$ distance$^{-4}$' '\n',
            rotation=-45, **label_kwargs
        )
        ax.set_title(run[:2])
    # Both panels share cmap and vmin/vmax, so the last scatter serves for the colorbar.
    cbar = plt.colorbar(scatter, cax=fig.add_subplot(gs[2]))
    cbar.set_label('Objective value', rotation=-90, va='bottom')
    fig.savefig('../figures/area-distance.pdf')

In [None]:
# Natural log of each run's simulated effective rate (in Gpc^-3 yr^-1), read
# from the `comment` column of the bayestar-inject row of the `process` table.
log_simulation_effective_rate_by_run = {}
for run in runs:
    db_uri = f"file:{base_path / run / 'farah' / 'events.sqlite'}?mode=ro"
    # A sqlite3.Connection used as a context manager only commits/rolls back;
    # it does not close the connection. closing() releases it.
    with closing(sqlite3.connect(db_uri, uri=True)) as db:
        # Unpacking requires exactly one matching row (raises ValueError otherwise).
        (comment,), = db.execute("SELECT comment FROM process WHERE program = 'bayestar-inject'")
    log_simulation_effective_rate_by_run[run] = np.log(u.Quantity(comment).to_value(u.Gpc**-3 * u.yr**-1))
log_simulation_effective_rate_by_run

In [None]:
# Merger-rate prior from the O3 R&P paper, Table II row 1, last column
# (https://doi.org/10.1103/PhysRevX.13.011048): the 5%, 50%, and 95%
# quantiles of the total merger rate in Gpc^-3 yr^-1.
lo, mid, hi = 100, 240, 510

# Log-normal model: the median sets mu; sigma is the log-width of the 5%-95%
# range divided by the width of the standard normal's central 90% interval.
z_lower, z_upper = stats.norm.interval(0.9)
standard_90pct_interval = z_upper - z_lower
log_target_rate_mu = np.log(mid)
log_target_rate_sigma = np.log(hi / lo) / standard_90pct_interval
log_target_rate_mu, log_target_rate_sigma

In [12]:
# Quantiles (median, 5%, 95%) of the expected number of events per run, from
# poisson_lognormal_rate_quantiles.
prob_quantiles = np.asarray([0.5, 0.05, 0.95])
run_duration = 1.5  # years
row_labels = ['Number of events selected', 'Number of events detected']
# mu[i_run, i_quantity]: log expected count over the run for (events with a
# positive objective value, summed known-position detection probability),
# rescaled from the simulated rate to the target rate.
mu = np.asarray([
    log_target_rate_mu + np.log(run_duration) - log_simulation_effective_rate_by_run[run] + np.log([
        np.sum(values) for values in [
            event_tables_by_run[run]['objective_value'] > 0,
            event_tables_by_run[run]['detection_probability_known_position'],
        ]
    ]) for run in runs
])
# Each table row corresponds to one quantity above; keep labels in sync.
assert mu.shape == (len(runs), len(row_labels))

# Broadcast (1, 1, n_prob) with (n_quantity, n_run, 1); assuming elementwise
# broadcasting, rate_quantiles is indexed [i_quantity, i_run, i_prob].
rate_quantiles = poisson_lognormal_rate_quantiles(prob_quantiles[np.newaxis, np.newaxis, :], mu.T[:, :, np.newaxis], log_target_rate_sigma)

# One LaTeX row per quantity and one column per run. Every row except the last
# ends with `\\`; the decision depends on the number of rows, not runs.
with open('../tables/selected-detected.tex', 'w') as f:
    for i, (label, row) in enumerate(zip(row_labels, rate_quantiles)):
        cells = ('${}_{{-{}}}^{{+{}}}$'.format(*format_with_errorbars(*col)) for col in row)
        is_last_row = i == len(row_labels) - 1
        print(label, *cells, sep=' & ', end='\n' if is_last_row else ' \\\\\n', file=f)

In [None]:
# Stacked quantile bars per run: for each quantity (selected, detected), the
# 95% quantile as a dark bar, the 5% quantile as a light bar on top of it, and
# the median as a black tick. Paired palette entries 2i+1 / 2i are the
# dark / light shades for quantity i.
colors = sns.color_palette('Paired')
xlabels = [run[:2] for run in runs]

fig, ax = plt.subplots()
x = np.arange(len(runs))
width = 0.4
offset = 0.4
for i_quantity, quantiles in enumerate(rate_quantiles):
    xpos = x + i_quantity * offset
    ax.bar(xpos, quantiles[:, 2], width=width, color=colors[2 * i_quantity + 1])
    ax.bar(xpos, quantiles[:, 1], width=width, color=colors[2 * i_quantity])
    ax.hlines(quantiles[:, 0], xpos - 0.5 * width, xpos + 0.5 * width, color='black')
# Center each run's label under its pair of bars.
ax.set_xticks(x + 0.5 * offset * (len(rate_quantiles) - 1))
ax.set_xticklabels(xlabels)
ax.set_ylabel('Number of events');


In [None]:
# Median expected number of events per run, with error bars spanning the
# 5%-95% quantiles; one bar per quantity, side by side around each run's tick.
fig, ax = plt.subplots()
x = np.arange(len(runs))
width = 0.4
for i_quantity, (label, quantiles) in enumerate(zip(['Selected', 'Detected'], rate_quantiles)):
    xpos = x + (i_quantity - 0.5) * width
    median, lower, upper = quantiles[:, 0], quantiles[:, 1], quantiles[:, 2]
    ax.bar(xpos, median, width=width, label=label)
    ax.errorbar(xpos, median, np.vstack((median - lower, upper - median)), color='k', marker='D', linestyle='none', capsize=10, capthick=plt.rcParams['lines.linewidth'])
ax.set_xticks(x)
ax.set_xticklabels([run[:2] for run in runs])
ax.set_ylabel('Number of events')
ax.legend();

In [None]:
# O6 events with objective value > 0.5, 90% credible area > 100 (deg^2), and
# distance between 200 and 1000 (Mpc). Bind `event_table` explicitly instead of
# relying on a table leaked from an earlier loop.
event_table = event_tables_by_run['O6HLVK']
event_table[
    (event_table['objective_value'] > 0.5)
    & (event_table['area(90)'] > 100)
    & (event_table['distance'] > 200)
    & (event_table['distance'] < 1000)
]

In [None]:
# Events within 25 deg of the Galactic plane (|b| < 25 deg; longitude/latitude
# columns in radians) with unknown-position detection probability > 0.5 and
# 90% credible area > 50.
galactic_latitude_deg = SkyCoord(event_table['longitude'], event_table['latitude'], unit=u.rad).galactic.b.deg
near_galactic_plane = np.abs(galactic_latitude_deg) < 25
likely_detected = event_table['detection_probability_unknown_position'] > 0.5
large_area = event_table['area(90)'] > 50
event_table[near_galactic_plane & likely_detected & large_area]