In [1]:
import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")

from contextlib import closing
import os
from pathlib import Path
import sqlite3

from astropy.coordinates import Distance, EarthLocation, ICRS
from astropy.time import Time
from astropy.table import QTable, join, unique
import astropy_healpix as ah
from astropy import units as u
from m4opt.fov import footprint_healpix
from matplotlib import pyplot as plt
from matplotlib import patheffects
from m4opt.missions import uvex as mission
from m4opt.models import observing, DustExtinction
import ligo.skymap
from ligo.skymap import distance
from ligo.skymap.io import read_sky_map
from ligo.skymap.bayestar import rasterize
import numpy as np
from scipy import stats
import synphot
from tqdm.auto import tqdm

# Limit ligo.skymap's OpenMP runtime to a single thread.
# NOTE(review): presumably to avoid CPU oversubscription when this notebook
# shares a machine with other jobs — confirm.
ligo.skymap.omp.num_threads = 1

In [2]:
# Root directory of the simulation run: it holds allsky.dat, injections.dat,
# events.sqlite and the per-event sky maps/plans under allsky/.
# Set the UVEX_RUN_DIR environment variable to point the notebook at another
# run; the default is the original author's location.
base_path = Path(os.environ.get('UVEX_RUN_DIR', '/home/lsinger/lustre/runs_SNR-10/O5HLVK/farah'))

In [None]:
# Summary data for all events: per-event sky map statistics (allsky.dat) and
# the injected source parameters (injections.dat), joined on their shared
# columns (astropy's `join` uses every common column name as a key).
allsky_stats = QTable.read(base_path / 'allsky.dat', format='ascii')
injections = QTable.read(base_path / 'injections.dat', format='ascii')
event_table = join(allsky_stats, injections)

# Later cells look up plans and sky maps by row position (file `{i}.*` for
# row i), so row i must be coinc_event_id i.
assert (event_table['coinc_event_id'] == np.arange(len(event_table))).all()
event_table

In [None]:
# Read the observing plan of every event; plan i belongs to coinc_event_id i.
plans = [QTable.read(base_path / 'allsky' / f'{i}.ecsv') for i in tqdm(range(len(event_table)))]

# Save the optimizer's objective value for every plan.
event_table['objective_value'] = [plan.meta['objective_value'] for plan in plans]

# Planner arguments. They are assumed identical for every event, so take them
# from the first plan; drop the per-event input sky map.
plan_args = dict(plans[0].meta['args'])
plan_args.pop('skymap')

# FIXME: `cutoff` is not recorded in the plan metadata at the moment, so fall
# back to the assumed value 0.1. Use `setdefault` so that a value recorded by
# a newer planner takes precedence instead of being silently overwritten.
plan_args.setdefault('cutoff', 0.1)

# HEALPix grid matching the planner's resolution (NESTED ordering).
hpx = ah.HEALPix(nside=plan_args['nside'], frame=ICRS(), order='nested')

In [None]:
def get_detection_probability_known_position(plan, event_row):
    """Probability that ``plan`` detects the event, given its true position.

    The event's true sky position (``event_row['longitude']`` and
    ``event_row['latitude']``, in radians) is matched against the footprints
    of the distinct fields in the plan. The fields that contain it each yield
    a limiting magnitude, and the deepest one is used. That limit is converted
    to a limiting absolute magnitude at the event's luminosity distance
    (``event_row['distance']``, in Mpc).

    The result is the probability that an absolute magnitude drawn from
    Normal(``plan_args['absmag_mean']``, ``plan_args['absmag_stdev']``) is
    brighter than that limit.

    Parameters
    ----------
    plan : astropy.table.QTable
        Observing plan; rows with ``action == 'observe'`` are exposures.
    event_row : astropy.table.Row
        Row of ``event_table`` describing the same event.

    Returns
    -------
    float
        Detection probability in [0, 1]. It is 0 when the plan is empty, has
        no exposures, or none of its fields covers the event.
    """
    if len(plan) == 0:
        return 0.0

    observations = plan[plan['action'] == 'observe'].filled()
    # A non-empty plan may still contain no exposures. Bail out before handing
    # an empty table to `unique` and the footprint/limmag calls below.
    if len(observations) == 0:
        return 0.0

    # Keep the first exposure of each distinct pointing (keyed by RA/Dec), in
    # plan order, so repeated visits collapse to one field.
    # NOTE(review): roll is not part of the key — confirm revisits never change roll.
    coords = observations['target_coord'].to_table()
    coords['i'] = np.arange(len(coords))
    i = np.sort(unique(coords, keys=['ra', 'dec'])['i'])
    fields = observations[i]

    # Restrict to the fields whose footprint contains the event's pixel.
    target_ipix = hpx.lonlat_to_healpix(event_row['longitude'] * u.rad, event_row['latitude'] * u.rad)
    footprints = footprint_healpix(hpx, mission.fov, fields['target_coord'], fields['roll'])
    target_in_field = [target_ipix in footprint for footprint in footprints]
    fields = fields[target_in_field]
    if len(fields) == 0:
        return 0.0

    # Limiting magnitude of a flat 0 ABmag spectrum attenuated by
    # `DustExtinction()`. It is evaluated in the observing context of each
    # remaining field (location, pointing, exposure midpoint). The deepest
    # (largest) limit sets the threshold.
    with observing(
        observer_location=fields['observer_location'],
        target_coord=fields['target_coord'],
        obstime=(fields['start_time'] + 0.5 * fields['duration'])
    ):
        spectrum = synphot.SourceSpectrum(synphot.ConstFlux1D, amplitude=0 * u.ABmag) * synphot.SpectralElement(DustExtinction())
        limmag = mission.detector.get_limmag(
            plan_args['snr'],
            fields['duration'],
            spectrum,
            plan_args['bandpass']
        ).max()

    # Convert to a limiting absolute magnitude via the distance modulus.
    lim_absmag = limmag - Distance(event_row['distance'] * u.Mpc).distmod
    # P(M < lim_absmag): smaller magnitudes are brighter, hence detectable.
    return stats.norm(loc=plan_args['absmag_mean'], scale=plan_args['absmag_stdev']).cdf(lim_absmag.to_value(u.mag))

# Detection probability of each event's plan, assuming the event's true
# position is known (one value per row of event_table).
event_table['detection_probability'] = [
    get_detection_probability_known_position(event_plan, row)
    for event_plan, row in zip(tqdm(plans), event_table)
]

In [6]:
# Rough limits on what the planner can reach, used to annotate the figure below.
available_time = plan_args['deadline'] - plan_args['delay']

# Deepest limiting magnitude anywhere on the sky for a flat 0 ABmag spectrum
# (no dust extinction), for the longest single exposure allowed. The observer
# is placed at the geocenter at a fixed epoch.
with observing(
    observer_location=EarthLocation(0 * u.m, 0 * u.m, 0 * u.m),
    target_coord=hpx.healpix_to_skycoord(np.arange(hpx.npix)),
    obstime=Time('2025-01-01'),
):
    exposure_time_limit = min(available_time, plan_args['exptime_max'])
    flat_spectrum = synphot.SourceSpectrum(synphot.ConstFlux1D, amplitude=0 * u.ABmag)
    limmag = mission.detector.get_limmag(
        plan_args['snr'], exposure_time_limit, flat_spectrum, plan_args['bandpass']
    ).max()

# Credible level (%) of the localization area plotted below.
skymap_area_cl = 90
# NOTE(review): presumably the area of one field of view (3.5° × 3.5°) — confirm.
min_area = (3.5 * u.deg)**2

# For a 2D Gaussian, the radius enclosing probability p is chi(df=2).ppf(p).
# The 90% area is therefore `area_factor` times the area enclosing `cutoff`.
ppf = stats.chi(df=2).ppf
area_factor = (ppf(0.9) / ppf(plan_args['cutoff']))**2

# Largest 90% area: as many fields as fit in the available time at `visits`
# visits of `exptime_min` each, scaled up by `area_factor`.
max_area = (area_factor * min_area * available_time / (plan_args['visits'] * plan_args['exptime_min']))

# Largest distance: invert m - M = 5 log10(d / Mpc) + 25. M is the absolute
# magnitude that is brighter than the mean by the (1 - cutoff) normal quantile.
brightest_absmag = plan_args['absmag_mean'] - stats.norm.ppf(1 - plan_args['cutoff']) * plan_args['absmag_stdev']
max_distance = 10**(0.2 * (limmag.to_value(u.mag) - brightest_absmag - 25)) * u.Mpc

# The line area ∝ distance^-4 passes through (max_distance, min_area); this is
# where it reaches max_area.
crossover_distance = max_distance * (min_area / max_area)**.25

In [None]:
# Localization area vs. luminosity distance for every event.
# - Grey dots: all events.
# - Colored dots: sized by the known-position detection probability and
#   colored by the planner's objective value.
# - Overlaid curve: the rough planner limits computed above — `max_area`,
#   then an area ∝ distance^-4 segment, then `max_distance`.
fig = plt.figure()
# With both axes logarithmic, aspect=0.25 appears chosen so that a log-log
# slope of -4 is drawn at -45°, matching the rotation of the label below.
ax = fig.add_subplot(aspect=0.25)
ax.scatter('distance', f'area({skymap_area_cl})', s=1, facecolor='0.7', edgecolor='none', data=event_table)
ax.scatter(
    'distance', f'area({skymap_area_cl})',
    s=event_table['detection_probability'] * 50,
    c='objective_value', cmap='cool', vmin=-0.5, vmax=1,
    edgecolor='none', data=event_table,
)
ax.set_xlim(5e1, 5e3)
ax.set_ylim(5e-2, u.spat.to(u.deg**2))  # upper limit: the whole sky
lines, = ax.plot(
    u.Quantity([ax.get_xlim()[0] * u.Mpc, crossover_distance, max_distance, max_distance]).to_value(u.Mpc),
    u.Quantity([max_area, max_area, min_area, ax.get_ylim()[0] * u.deg**2]).to_value(u.deg**2)
)
# Labels share the curve's color and get a white halo for legibility.
kwargs = dict(
    color=lines.get_color(),
    ha='center',
    va='bottom',
    rotation_mode='anchor',
    path_effects=[patheffects.withStroke(linewidth=2.5, foreground='white')]
)
# Each label sits at the geometric mean of its segment's ends, i.e. the
# segment's midpoint in log space.
ax.text(np.sqrt(ax.get_xlim()[0] * crossover_distance.to_value(u.Mpc)), max_area.to_value(u.deg**2), 'Max area', **kwargs)
ax.text(max_distance.to_value(u.Mpc), np.sqrt(ax.get_ylim()[0] * min_area.to_value(u.deg**2)), 'Max distance', rotation=-90, **kwargs)
# Raw string: '\p' is an invalid escape sequence in a normal string literal.
ax.text(np.sqrt(crossover_distance.to_value(u.Mpc) * max_distance.to_value(u.Mpc)), np.sqrt(min_area.to_value(u.deg**2) * max_area.to_value(u.deg**2)), r'Area $\propto$ distance$^{-4}$', rotation=-45, **kwargs)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xlabel('Luminosity distance (Mpc)')
ax.set_ylabel(f'{skymap_area_cl}% credible area (deg$^2$)');

In [None]:
# Effective rate of the simulated population, stored by bayestar-inject in its
# process-table comment. The database is opened read-only through a SQLite URI.
# `as_uri()` percent-encodes any special characters in the path.
events_db_uri = (base_path / 'events.sqlite').absolute().as_uri() + '?mode=ro'
# A sqlite3.Connection used as a context manager only commits/rolls back; it
# does not close the connection. `closing` guarantees it is closed.
with closing(sqlite3.connect(events_db_uri, uri=True)) as db:
    # The nested unpacking requires exactly one row with one column and raises
    # ValueError otherwise.
    (comment,), = db.execute("SELECT comment FROM process WHERE program = 'bayestar-inject'")
simulation_effective_rate = u.Quantity(comment)
simulation_effective_rate

In [None]:
# Rescale the summed detection probability from the simulation's effective
# rate to an assumed astrophysical rate. The result is dimensionless.
# NOTE(review): presumably expected detections per year of observing, and the
# origin of 240 Gpc^-3 yr^-1 is not stated here — confirm both.
target_rate = 240 * u.Gpc**-3 * u.yr**-1
total_detection_probability = np.sum(event_table['detection_probability'])
expected_detections = total_detection_probability * target_rate / simulation_effective_rate
expected_detections.to(u.dimensionless_unscaled)

In [None]:
# Candidate events for a closer look: a high objective value, a fairly large
# localization (deg²) and an intermediate distance (Mpc). The area column
# follows `skymap_area_cl` so that it matches the figure above.
is_candidate = (
    (event_table['objective_value'] > 0.5)
    & (event_table[f'area({skymap_area_cl})'] > 100)
    & (event_table['distance'] > 200)
    & (event_table['distance'] < 500)
)
event_table[is_candidate]

In [None]:
# Case study: load the plans computed for one event with a series of fixed
# exposure times (300 s to 3600 s in 100 s steps).
event_id = 800
exptimes = np.arange(300, 3700, 100)
plan_dir = base_path / 'allsky'
fixed_exptime_plans = [
    QTable.read(plan_dir / f'{event_id}-exptime-{exptime_s}s.ecsv')
    for exptime_s in tqdm(exptimes)
]

In [23]:
def get_detection_probability_unknown_position(plan, skymap_event_id=None):
    """Detection probability of ``plan``, marginalized over the event's sky map.

    For every HEALPix pixel covered by the plan, two probabilities are
    multiplied: that the event lies in the pixel (``PROB``), and that its
    apparent magnitude is brighter than the limiting magnitude reached there.
    The products are summed over pixels.

    The apparent magnitude in a pixel is modeled as normal:

    - the absolute magnitude M is Normal(``plan_args['absmag_mean']``,
      ``plan_args['absmag_stdev']``);
    - the pixel's distance distribution (from ``DISTMU``/``DISTSIGMA``) is
      approximated by a log-normal with the same mean and standard deviation;
    - hence m = M + 5 log10(d / Mpc) + 25 is a sum of two normals.

    Parameters
    ----------
    plan : astropy.table.QTable
        Observing plan; rows with ``action == 'observe'`` are exposures.
    skymap_event_id : int, optional
        Event whose sky map ``base_path / 'allsky' / f'{id}.fits'`` is used.
        Defaults to the notebook-level ``event_id`` (the previous behavior).

    Returns
    -------
    float
        Detection probability in [0, 1]. It is 0 when the plan is empty, has
        no exposures, or covers no pixel.
    """
    if skymap_event_id is None:
        # Backward compatibility: earlier versions always read the global
        # `event_id`; it is looked up at call time, as before.
        skymap_event_id = event_id

    if len(plan) == 0:
        return 0.0

    observations = plan[plan['action'] == 'observe'].filled()
    if len(observations) == 0:
        # Nothing observed. This also avoids handing empty input to `unique`
        # and to the footprint/limmag calls below.
        return 0.0

    # First exposure of each distinct pointing (keyed by RA/Dec), in plan order.
    coords = observations['target_coord'].to_table()
    coords['i'] = np.arange(len(coords))
    i = np.sort(unique(coords, keys=['ra', 'dec'])['i'])
    fields = observations[i]

    # Sky map rasterized to the planner's resolution.
    # NOTE(review): assumes the rasterized rows are in NESTED order, so row k
    # is pixel k of `hpx`.
    skymap_moc = read_sky_map(base_path / 'allsky' / f'{skymap_event_id}.fits', moc=True)
    skymap = rasterize(skymap_moc, order=ah.nside_to_level(plan_args['nside']))

    # Longest single exposure covering each pixel (0 where the plan never looks).
    durations = np.zeros(hpx.npix)
    footprints = footprint_healpix(hpx, mission.fov, fields['target_coord'], fields['roll'])
    for ipix, duration in zip(footprints, fields['duration'].to_value(u.s)):
        durations[ipix] = np.maximum(durations[ipix], duration)
    skymap['duration'] = durations * u.s
    skymap['ipix'] = np.arange(hpx.npix)

    # Uncovered pixels cannot contribute.
    skymap = skymap[durations > 0]
    if len(skymap) == 0:
        return 0.0

    # Limiting magnitude in each covered pixel for a flat 0 ABmag spectrum
    # attenuated by `DustExtinction()`.
    # NOTE(review): a single observer location and time (the plan's first row)
    # is used for every pixel, whereas the known-position estimate uses each
    # field's exposure midpoint — confirm this approximation is intended.
    with observing(
        observer_location=plan['observer_location'][0],
        target_coord=hpx.healpix_to_skycoord(skymap['ipix']),
        obstime=plan['start_time'][0]
    ):
        spectrum = synphot.SourceSpectrum(synphot.ConstFlux1D, amplitude=0 * u.ABmag) * synphot.SpectralElement(DustExtinction())
        skymap['limmag'] = mission.detector.get_limmag(
            plan_args['snr'],
            skymap['duration'],
            spectrum,
            plan_args['bandpass']
        )

    # Moment-matched log-normal: ln d ~ Normal(logdistmu, logdistsigma).
    distmean, diststd, _ = distance.parameters_to_moments(skymap['DISTMU'], skymap['DISTSIGMA'])
    sigma2_log = np.log1p(np.square(diststd / distmean))
    logdistsigma = np.sqrt(sigma2_log)
    logdistmu = np.log(distmean) - 0.5 * sigma2_log

    # m = M + 5 log10(d / Mpc) + 25 = M + a ln(d / Mpc) + 25 with a = 5 / ln 10.
    a = 5 / np.log(10)
    appmagmu = plan_args['absmag_mean'] + a * logdistmu + 25
    appmagsigma = np.sqrt(np.square(plan_args['absmag_stdev']) + np.square(a * logdistsigma))

    # Sum over covered pixels of P(event in pixel) * P(m < limiting magnitude).
    return (skymap['PROB'] * stats.norm(loc=appmagmu, scale=appmagsigma).cdf(skymap['limmag'])).sum()

In [None]:
# Marginalized detection probability for the planner's own plan for
# `event_id`, followed by one value per fixed-exposure-time plan.
prob_adaptive, *probs = map(
    get_detection_probability_unknown_position,
    tqdm([plans[event_id], *fixed_exptime_plans]),
)

In [None]:
# Marginalized detection probability for event `event_id` as a function of a
# fixed exposure time, compared with the planner's own (adaptive) plan.
fig, ax = plt.subplots()
ax.plot(exptimes, probs, label='Fixed exposure time')
ax.axhline(prob_adaptive, color='C1', linestyle='--', label='Adaptive plan')
ax.set_xlabel('Exposure time (s)')
ax.set_ylabel('Detection probability')
ax.set_title(f'Event {event_id}')
ax.legend();