## Before you begin:
- Clone the repo and navigate to the correct directory.
- Install all package requirements.

In [None]:
! git clone https://github.com/broker-workshop/tutorials.git



In [None]:
%cd tutorials/ANTARES/data

/content/tutorials/ANTARES/data


In [None]:
%%capture
! pip install matplotlib pandas astropy astroquery ligo.skymap healpy pickle5 keras-tcn
! pip install git+https://github.com/deepchatterjeeligo/astrorapid.git@broker-workshop

The typical scenario considered here is:
* A binary neutron star coalescence is observed by the LIGO/Virgo/KAGRA (LVK) gravitational-wave detectors.
* The LVK alert contains a sky localization, which is a probability distribution in sky coordinates.
    
  > GW data alone can only provide a sky localization, whose size depends on the strength of the signal, the bandwidth of the signal, and the participating detectors. These regions can span $\mathcal{O}(10)$ to $\mathcal{O}(1000)$ sq. deg. on the sky.

* There is an associated electromagnetic counterpart: the _kilonova_, as for GW170817 ![GW170817](https://cfn-live-content-bucket-iop-org.s3.amazonaws.com/journals/2041-8205/848/2/L12/1/apjlaa91c9f1_lr.jpg?AWSAccessKeyId=AKIAYDKQL6LTV7YY2HIK&Expires=1618884299&Signature=03jjAeLtBiofb77vjD%2BNC3M1n3U%3D)
* There may be several contaminants in the field of view of the kilonova, especially as telescopes become more sensitive. We need to downselect potential candidates.
* Communicate candidates to follow-up scheduling facilities, for example [Treasure Map](http://treasuremap.space/).
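
To make "a probability distribution in sky coordinates" concrete: a HEALPix skymap stores one probability per equal-area pixel, and the area of a credible region can be found greedily by adding pixels from most to least probable until the requested probability is enclosed. The sketch below assumes equal-area pixels whose probabilities sum to 1; `credible_region_area` and the toy numbers are hypothetical. For a real map the pixel area could come from `healpy.nside2pixarea(nside, degrees=True)`, and `ligo.skymap`'s `crossmatch` (used below) reports such areas directly.

```python
import numpy as np

def credible_region_area(prob, pixel_area_deg2, level=0.9):
    """Area (deg^2) of the smallest set of pixels enclosing `level` probability.

    Assumes `prob` holds per-pixel probabilities on equal-area pixels
    (as in a HEALPix map) that sum to 1.
    """
    prob = np.asarray(prob, dtype=float)
    # Greedy ranking: most probable pixels first
    ranked = np.sort(prob)[::-1]
    cumulative = np.cumsum(ranked)
    # Number of pixels needed to reach the requested credible level
    n_pix = np.searchsorted(cumulative, level) + 1
    return n_pix * pixel_area_deg2

# Hypothetical 6-pixel "skymap" with 10 deg^2 pixels
toy_prob = np.array([0.05, 0.40, 0.03, 0.15, 0.25, 0.12])
print(credible_region_area(toy_prob, 10.0, level=0.5))  # 20.0
print(credible_region_area(toy_prob, 10.0, level=0.9))  # 40.0
```

Because pixels are ranked by probability rather than by distance from the peak, the same procedure works for multimodal localizations made of several disjoint patches.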

We will work with simulated light curves because:
- Only one kilonova (GW170817) has been observed so far.
- Data are proprietary at the time of GW discovery (at least from the high-probability region of the skymap).
- We need simulated objects for mock tests (see the talk and demo on PLAsTiCC).

In [None]:
import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
try:
    lightcurves = pd.read_pickle('lightcurves.pickle')
except ValueError:
    import pickle5
    with open('lightcurves.pickle', 'rb') as f:
        lightcurves = pickle5.load(f)

## Here are 10 lightcurves (9 contaminants and 1 kilonova)

These were simulated using the ZTF observing cadence. The last one is the kilonova (KN).

(For those interested, the SED models used in the simulations are publicly available as part of the [SNANA package data](https://zenodo.org/record/4015340#.YHZ3ehJOlcA).)

In [None]:
lightcurves.SIM_MODEL_NAME

## Plot the KN

In [None]:
kn = lightcurves.iloc[-1]  # the last one is the KN

In [None]:
plt.figure(figsize=(14, 6))

pkmjd_kn = kn.pkmjd
mjd_r_kn = kn.mjd_r - pkmjd_kn
mjd_g_kn = kn.mjd_g - pkmjd_kn

plt.errorbar(mjd_r_kn, kn.mag_r, yerr=kn.magerr_r, color='red',
             fmt='*', label='r')
plt.errorbar(mjd_g_kn, kn.mag_g, yerr=kn.magerr_g, color='green',
             fmt='^', label='g')
plt.gca().invert_yaxis()  # brighter (smaller) magnitudes at the top
plt.legend(fontsize=14)
plt.xlabel('Days since peak', fontsize=14)
plt.ylabel('Magnitude', fontsize=14)
plt.title("Kilonova")
plt.show()

Kilonovae evolve rapidly: compare this light curve with the supernova below. Given the LSST cadence, discovery will be hard, so a rapid turnaround time is needed.
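
One way to quantify "evolves rapidly" is the post-peak decline rate in magnitudes per day, taken from a straight-line fit to the points after peak. This is a minimal sketch on hypothetical toy light curves; applying it to the real data (for example `decline_rate(mjd_r_kn, kn.mag_r)`) assumes the magnitude arrays contain only finite values.

```python
import numpy as np

def decline_rate(days_since_peak, mag, min_phase=0.0):
    """Least-squares slope (mag/day) of the light curve after `min_phase`."""
    t = np.asarray(days_since_peak, dtype=float)
    m = np.asarray(mag, dtype=float)
    post_peak = t >= min_phase
    # np.polyfit with deg=1 returns [slope, intercept]
    slope, _ = np.polyfit(t[post_peak], m[post_peak], deg=1)
    return slope

# Hypothetical toy light curves: one fading 1 mag/day, one fading 0.05 mag/day
t = np.array([-2.0, 0.0, 1.0, 2.0, 3.0, 4.0])
fast = np.where(t < 0, 19.0, 18.0 + 1.0 * t)
slow = np.where(t < 0, 19.0, 18.0 + 0.05 * t)
print(round(decline_rate(t, fast), 3))  # 1.0
print(round(decline_rate(t, slow), 3))  # 0.05
```

A positive slope means the object is fading; a transient that drops by about a magnitude per day leaves only a few nights in which a survey can catch it.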

## Plot a few contaminant objects

In [None]:
plt.figure(figsize=(14, 6))

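for idx, row in lightcurves.iterrows():
    if 'SALT' not in row.SIM_MODEL_NAME:  # plot SALT2 Ia; try a few others from the list above
        continue
    pkmjd = row.pkmjd
    mjd_r = row.mjd_r - pkmjd
    mjd_g = row.mjd_g - pkmjd
    plt.errorbar(mjd_r, row.mag_r, yerr=row.magerr_r, color='red',
                 fmt='*', label='r')
    plt.errorbar(mjd_g, row.mag_g, yerr=row.magerr_g, color='green',
                 fmt='^', label='g')
plt.gca().invert_yaxis()  # brighter (smaller) magnitudes at the top
# keep one legend entry per band even when several objects are plotted
handles, labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(labels, handles))
plt.legend(by_label.values(), by_label.keys(), fontsize=14)
plt.xlabel('Days since peak', fontsize=14)
plt.ylabel('Magnitude', fontsize=14)
plt.title("Contaminants")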
plt.show()

## Associated skymap

In [None]:
import healpy as hp
from ligo.skymap.io import read_sky_map

In [None]:
skymap, *h = read_sky_map('skymap.fits')
plt.figure(figsize=(15, 10))
hp.mollview(skymap, fig=0, cmap='YlOrBr', cbar=False, hold=True,
            title="Mock LVK skymap associated with KN")

# plot locations of all objects
for idx, row in lightcurves.iterrows():
    hp.visufunc.projplot(
        row.SIM_RA, row.SIM_DEC, lonlat=True, marker='x',
        c='r' if row.true_label else 'b',
        markersize=15 if row.true_label else 5
    )
hp.graticule()

The skymap shows several supernovae that coincide in sky location with the kilonova; these are the "contaminants". We need to single out the kilonova from these other events.

A simple filter can:
- check temporal coincidence with the GW trigger
- check consistency with the skymap
- crossmatch with a galaxy catalog
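
The three checks above can be expressed as a small, library-free predicate. This only illustrates the logic and is not the ANTARES implementation: the `Candidate` record, `passes_simple_filter`, and all numbers are hypothetical, and in practice the searched area and 90% area would come from a skymap crossmatch.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Candidate:
    """Hypothetical minimal record for one transient (not an ANTARES type)."""
    name: str
    first_mjd: float              # earliest observation of the transient
    searched_area_deg2: float     # skymap area searched from the mode to the object
    galaxy_sep_arcsec: Optional[float] = None  # None if no catalog match

def passes_simple_filter(cand, trigger_mjd, area90_deg2, max_sep_arcsec=None):
    # 1. Temporal coincidence: nothing seen before the GW trigger
    if cand.first_mjd < trigger_mjd:
        return False
    # 2. Skymap consistency: the object lies inside the 90% credible region
    if cand.searched_area_deg2 > area90_deg2:
        return False
    # 3. Optional host-galaxy cut; skipped when no threshold is given,
    #    since galaxy catalogs are incomplete
    if max_sep_arcsec is not None:
        if cand.galaxy_sep_arcsec is None or cand.galaxy_sep_arcsec > max_sep_arcsec:
            return False
    return True

cands = [
    Candidate("old_sn", 58340.0, 30.0, 2.0),
    Candidate("far_away", 58348.0, 900.0, 1.0),
    Candidate("kn_like", 58347.5, 50.0, 5.0),
]
print([c.name for c in cands if passes_simple_filter(c, 58347, 400.0)])
# ['kn_like']
```

Passing `max_sep_arcsec` turns the galaxy crossmatch into an actual cut; leaving it at `None` keeps the galaxy association informational, which is safer when the catalog may be incomplete.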

In [None]:
from astropy import coordinates, units as u
from astroquery.vizier import Vizier
from ligo.skymap.postprocess import crossmatch


gw_trigger_mjd = 58347
skymap_filename = 'skymap.fits'
skymap = read_sky_map(skymap_filename, moc=True)

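# load a galaxy catalog (GLADE via VizieR); note that astroquery's Vizier
# returns at most Vizier.ROW_LIMIT rows (50 by default), so this is a small sample
catalog, = Vizier.query_constraints(
    catalog='VII/281/glade2',
)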

In [None]:
def simple_kn_filter(obj):
    """Simple filter to downselect an associated KN"""
    # check if there are any detections before GW trigger time
    mjd = np.hstack((obj.mjd_g, obj.mjd_r, obj.mjd_i))
    print(f"\nChecking for object {obj.SIM_MODEL_NAME.strip()}")
    if np.any(mjd < gw_trigger_mjd):
        print("Found detections before GW trigger time")
        return False
    
    # check consistency with skymap
    obj_location = coordinates.SkyCoord(
        obj.SIM_RA * u.deg,  # in reality this will be part of the alert
        obj.SIM_DEC * u.deg
    )
    # get a ligo.skymap crossmatch result with 50% and 90% credible contours
    crossmatch_result = crossmatch(
        skymap, coordinates=(obj_location,),
        contours=(0.5, 0.9)
    )
    # get the skymap probability density at the object's location
    prob_density = crossmatch_result.probdensity
    # get angular offset from the skymap mode
    offset, *_ = crossmatch_result.offset * u.deg
    # get the 50% and 90% credible areas of the skymap
    area_fifty, area_ninety = crossmatch_result.contour_areas * u.deg**2
    # get the area searched from the posterior mode out to the object
    searched_area, *_ = crossmatch_result.searched_area * u.deg**2
    # put a threshold on searched area
    if searched_area > area_ninety:
        print(f"Searched area for {obj.SIM_MODEL_NAME.strip()} is {searched_area:.2f} > "
              f"90% area of {area_ninety:.2f}")
        return False
    
    # find the nearest galaxy in the catalog; no separation cut is applied,
    # so this is the closest galaxy, not necessarily a physical host
    catalog_object_locations = coordinates.SkyCoord(catalog['RAJ2000'], catalog['DEJ2000'])
    idx, sep2d, dist3d = obj_location.match_to_catalog_sky(catalog_object_locations)
    closest_galaxy = catalog[idx]
    print(
        f"Nearest catalog galaxy @ RA: {closest_galaxy['RAJ2000']:.3f} / "
        f"DEC: {closest_galaxy['DEJ2000']:.3f}"
    )
    print(
        f"Transient @ RA: {obj.SIM_RA:.3f} "
        f"DEC: {obj.SIM_DEC:.3f} / "
        f"Angular sep = {np.ravel(sep2d.deg)[0]:.3f} deg"
    )
    return True

In [None]:
[simple_kn_filter(row) for idx, row in lightcurves.iterrows()]

Thus, simple temporal and spatial selection cuts can help us downselect candidates. In a more realistic situation the transient may have an associated host galaxy (or it may not, since the galaxy catalog may be incomplete). Such temporal and skymap-based spatial cuts were used in ANTARES during O3 operations.
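
`match_to_catalog_sky` always returns the nearest galaxy, however far away it is, so a "match" only becomes meaningful once a separation threshold is applied. The NumPy sketch below shows the same nearest-neighbour idea with an optional cut; `nearest_galaxy`, `radec_to_unit`, and the toy coordinates are hypothetical, and astropy's `SkyCoord` remains the better tool for real catalogs.

```python
import numpy as np

def radec_to_unit(ra_deg, dec_deg):
    """Convert RA/Dec in degrees to Cartesian unit vectors (last axis = xyz)."""
    ra, dec = np.radians(ra_deg), np.radians(dec_deg)
    return np.stack(
        [np.cos(dec) * np.cos(ra), np.cos(dec) * np.sin(ra), np.sin(dec)], axis=-1
    )

def nearest_galaxy(ra, dec, cat_ra, cat_dec, max_sep_deg=None):
    """Return (index, separation in deg) of the nearest catalog galaxy.

    The index is None when the nearest galaxy lies beyond max_sep_deg.
    """
    target = radec_to_unit(ra, dec)  # shape (3,)
    catalog = radec_to_unit(np.asarray(cat_ra, float), np.asarray(cat_dec, float))  # (N, 3)
    # Largest dot product = smallest angular separation
    cos_sep = np.clip(catalog @ target, -1.0, 1.0)
    idx = int(np.argmax(cos_sep))
    sep_deg = float(np.degrees(np.arccos(cos_sep[idx])))
    if max_sep_deg is not None and sep_deg > max_sep_deg:
        return None, sep_deg
    return idx, sep_deg

# Hypothetical toy catalog and transient position
cat_ra = [10.0, 50.0, 10.1]
cat_dec = [20.5, -30.0, 20.0]
print(nearest_galaxy(10.0, 20.0, cat_ra, cat_dec))                    # index 2, ~0.094 deg
print(nearest_galaxy(10.0, 20.0, cat_ra, cat_dec, max_sep_deg=0.05))  # index None
```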

## Classifying the lightcurve (WIP)

The early-epoch classification code [RAPID](https://astrorapid.readthedocs.io/en/latest/) is part of ANTARES.

However, classifying from the light curve alone may be challenging, since there may not be enough data points yet. Hence we want to use the available contextual information. Luckily, we get some context for free: the LVK skymap itself.
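
The classification loop below passes two skymap-derived features, `offset` and `logprob`, to the classifier. Their exact definitions in `lightcurves.pickle` are not shown here; assuming they are the angular offset from the skymap mode and the log of the skymap probability density at the object's position (compare `crossmatch_result.offset` and `crossmatch_result.probdensity` above), they could be computed along these lines. `skymap_context`, `angular_offset_deg`, and the inputs are hypothetical.

```python
import math

def angular_offset_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation in degrees (haversine formula)."""
    ra1, dec1, ra2, dec2 = map(math.radians, (ra1, dec1, ra2, dec2))
    h = (math.sin((dec2 - dec1) / 2) ** 2
         + math.cos(dec1) * math.cos(dec2) * math.sin((ra2 - ra1) / 2) ** 2)
    return math.degrees(2 * math.asin(math.sqrt(min(1.0, h))))

def skymap_context(ra, dec, mode_ra, mode_dec, prob_density):
    """Hypothetical contextual features for one transient.

    prob_density is the skymap probability density at (ra, dec), e.g. from a
    crossmatch; it must be positive to take the log.
    """
    return dict(
        offset=angular_offset_deg(ra, dec, mode_ra, mode_dec),
        logprob=math.log(prob_density),
    )

print(skymap_context(120.0, -10.0, 120.0, -13.0, prob_density=2.5))
```

Working with the log of the density keeps the feature on a manageable scale, since densities across a skymap can differ by many orders of magnitude.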

In [None]:
from astrorapid import Classify

classification = Classify(
    model_filepath='trained_model.hdf5',
    known_redshift=False,
    passbands=('g', 'r', 'i'),
    class_names=('Pre-explosion', 'Kilonova', 'Other'),
    mintime=-5,
    timestep=3
)

In [None]:
def get_data_to_classify(obj):
    mjd = np.hstack((obj.mjd_g, obj.mjd_r, obj.mjd_i))
    sort_mask = np.argsort(mjd)

    flux = np.hstack((obj.fluxcal_g, obj.fluxcal_r, obj.fluxcal_i))
    fluxerr = np.hstack((obj.fluxcalerr_g, obj.fluxcalerr_r, obj.fluxcalerr_i))
    photflag = np.hstack((obj.photflag_g, obj.photflag_r, obj.photflag_i))
    passbands = np.array(obj.mjd_g.size*['g'] + obj.mjd_r.size*['r'] + obj.mjd_i.size*['i'])
    
    objid = obj.SIM_MODEL_NAME.strip()
    ra = obj.SIM_RA
    dec = obj.SIM_DEC
    redshift = obj.z
    mwebv = obj.SIM_MWEBV
    return (
        mjd[sort_mask], flux[sort_mask], fluxerr[sort_mask],
        passbands[sort_mask], photflag[sort_mask], ra, dec,
        objid, redshift, mwebv
    )
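
RAPID is fed calibrated fluxes (`fluxcal_*`) rather than the magnitudes plotted earlier. Assuming these columns follow SNANA's FLUXCAL convention, with a zero point of 27.5 mag, they convert to magnitudes as in the sketch below; `fluxcal_to_mag` is a hypothetical helper. One practical reason to classify in flux space is that noisy non-detections can have zero or negative flux, which has no magnitude (the sketch returns NaN for those).

```python
import numpy as np

SNANA_FLUXCAL_ZP = 27.5  # SNANA FLUXCAL zero point: mag = 27.5 - 2.5 log10(FLUXCAL)

def fluxcal_to_mag(fluxcal, fluxcalerr):
    """Convert FLUXCAL and its error to magnitude and magnitude error.

    Non-positive fluxes have no magnitude and are returned as NaN.
    """
    flux = np.asarray(fluxcal, dtype=float)
    err = np.asarray(fluxcalerr, dtype=float)
    with np.errstate(divide='ignore', invalid='ignore'):
        mag = np.where(flux > 0, SNANA_FLUXCAL_ZP - 2.5 * np.log10(flux), np.nan)
        # first-order error propagation: sigma_m = 2.5 / ln(10) * sigma_f / f
        magerr = np.where(flux > 0, 2.5 / np.log(10) * err / flux, np.nan)
    return mag, magerr

mag, magerr = fluxcal_to_mag([100.0, -5.0], [10.0, 3.0])
print(mag)  # first entry 22.5, second NaN (negative flux)
```

Comparing `fluxcal_to_mag(kn.fluxcal_r, kn.fluxcalerr_r)` with `kn.mag_r` would be a quick consistency check of this assumption.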

In [None]:
for idx, row in lightcurves.iterrows():
    print(f"\n#### True model: {row.SIM_MODEL_NAME.strip()} ####")
    lightcurve_data = get_data_to_classify(row)
    other_meta_data = dict(offset=row.offset, logprob=row.logprob)

    predictions, time_steps = classification.get_predictions(
        [lightcurve_data], return_predictions_at_obstime=False,
        other_meta_data=[other_meta_data]
    )
    if not predictions:
        continue
    kn_prediction = predictions[0].T[1]
    print("Time:", time_steps[0] - gw_trigger_mjd)
    print("KN class probabilities:", kn_prediction)