In [3]:
import sys
import os

# Make the directory one level above the notebook's working directory
# importable, so local packages such as `trigger_utils` can be found.
# The membership check keeps this cell idempotent: re-running it does not
# keep appending duplicate entries to sys.path.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [5]:
from trigger_utils.trigger_utils import get_kowalski_ztf_queue
import yaml

# Read the Fritz API token and the ZTF allocation from a local YAML file.
# The path is resolved relative to the notebook's current working directory.
with open('../config/Credentials.yaml', 'r') as creds_fh:
    credentials = yaml.safe_load(creds_fh)

fritz_token = credentials['fritz_token']    # token passed to the Fritz/Kowalski helpers
ztf_allocation = credentials['allocation']  # allocation used when querying the ZTF queue

In [7]:
# Query the current ZTF queue for this allocation and display the raw response
# (judging by the output: a dict with 'status', 'data' -> 'queue_names', 'version').
current_ztf_queue = get_kowalski_ztf_queue(
    fritz_token,
    ztf_allocation,
)
current_ztf_queue

{'status': 'success',
 'data': {'queue_names': ['Caltech_Prince_2025-05-01',
   'Caltech_Prince_2025-05-03',
   'EP_2025-04-30_12',
   'EP_2025-04-30_14',
   'EP_2025-04-30_15',
   'EP_2025-04-30_16',
   'EP_2025-04-30_5',
   'EP_2025-04-30_6',
   'EP_2025-04-30_8',
   'EP_2025-04-30_9',
   'ToO_S250319bu_BBHBot_2025-05-01 01:25:32.352',
   'Twilight_2025-04-30_e',
   'Twilight_2025-04-30_m',
   'Twilight_2025-05-01_e',
   'Twilight_2025-05-01_m',
   'Twilight_2025-05-02_e',
   'Twilight_2025-05-02_m',
   'Twilight_2025-05-03_e',
   'Twilight_2025-05-03_m',
   'Twilight_2025-05-04_e',
   'Twilight_2025-05-04_m',
   'default',
   'fallback',
   'missed_obs']},
 'version': '1.4.0+fritz.86bd000'}

In [10]:
# get plan stats

from trigger_utils.trigger_utils import get_plan_stats

# Inspect one specific observing plan. NOTE(review): gcnevent_id is presumably
# the Fritz GcnEvent id and queuename the plan's queue name -- confirm.
# get_plan_stats reports the plan's total time and probability (see output).
gcnevent_id = 13342
queuename = "S250319bu_BBHBot_2025-04-03 01:18:21.260"

stats = get_plan_stats(
    gcnevent_id,
    queuename,
    fritz_token,
    mode='',
)

Already submitted to queue
Total time: 2940, probability: 0.7988238394848364


In [1]:
import json
import logging
import time

import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import yaml
from astropy.time import Time
from tqdm import tqdm

# nuztf supplies the base class and helpers that SkymapScanner relies on
# (BaseScanner, Skymap, CONFIG_DIR, get_preprocessed_results).
from nuztf.ampel import get_preprocessed_results
from nuztf.api import api_name, api_skymap
from nuztf.base_scanner import BaseScanner
from nuztf.paths import CONFIG_DIR
from nuztf.skymap import Skymap


In [None]:
def check_skymap_ZTF_coverage():
    """
    Placeholder for the ZTF coverage of a skymap.

    Not implemented yet: it always reports zero coverage.

    :return: coverage value, currently always ``0.0``
    """
    # TODO: replace with a real coverage computation.
    return 0.0

In [None]:
# get ZTF coverage of a skymap

class SkymapScanner(BaseScanner):
    """
    Skymap-based scanner built on ``BaseScanner``.

    Loads the skymap of ``event`` via ``Skymap``, keeps the region selected by
    ``prob_threshold``, and defines a search window that starts at the skymap's
    ``t_obs`` and lasts ``n_days`` days. Also provides helpers to fetch
    precomputed results from the DESY cloud and to plot the skymap and the
    ZTF coverage.
    """

    # Default Fritz group id for this scanner.
    # NOTE(review): presumably consumed by BaseScanner when talking to Fritz -- confirm.
    default_fritz_group = 1563

    def __init__(
        self,
        event: str | None = None,
        rev: int | None = None,
        prob_threshold: float = 0.9,
        cone_nside: int = 64,
        output_nside: int | None = None,
        n_days: float = 3.0,  # By default, accept things detected within 72 hours of event time
        config: dict | None = None,
    ):
        """
        :param event: event name passed through to ``Skymap``
        :param rev: skymap revision passed to ``Skymap``; the revision actually
            used is read back from the loaded skymap into ``self.rev``
        :param prob_threshold: probability threshold passed to ``Skymap``
        :param cone_nside: nside forwarded to ``BaseScanner`` (used as
            ``self.cone_nside`` by ``find_cone_coords``)
        :param output_nside: optional nside passed to ``Skymap``
        :param n_days: length of the search window after ``t_min``, in days
        :param config: run configuration; if falsy, ``gw_run_config.yaml`` is
            loaded from ``CONFIG_DIR``
        """
        self.logger = logging.getLogger(__name__)
        self.prob_threshold = prob_threshold
        self.n_days = n_days
        self.event = event
        self.output_nside = output_nside

        if config:
            self.config = config
        else:
            config_path = CONFIG_DIR.joinpath("gw_run_config.yaml")
            with open(config_path) as f:
                self.config = yaml.safe_load(f)

        self.skymap = Skymap(
            event=self.event,
            rev=rev,
            prob_threshold=self.prob_threshold,
            output_nside=self.output_nside,
        )
        self.rev = self.skymap.rev

        self.t_min = Time(self.skymap.t_obs, format="isot", scale="utc")

        # The base class is initialised only after the skymap and t_min exist.
        # NOTE(review): presumably BaseScanner.__init__ calls back into
        # unpack_skymap()/find_cone_coords(), which need self.skymap -- confirm.
        BaseScanner.__init__(
            self,
            run_config=self.config,
            t_min=self.t_min,
            cone_nside=cone_nside,
        )

        # Julian dates are in days, so n_days is added directly.
        self.default_t_max = Time(self.t_min.jd + self.n_days, format="jd")
        self.logger.info(f"Time-range is {self.t_min} -- {self.default_t_max.isot}")

    def get_full_name(self):
        """Return the skymap's event name, or ``"?????"`` when it is unknown."""
        if self.skymap.event_name is not None:
            return self.skymap.event_name
        else:
            return "?????"

    def get_name(self) -> str:
        """Return ``"<event_name>/<prob_threshold>"`` (used e.g. in log messages)."""
        return f"{self.skymap.event_name}/{self.prob_threshold}"

    def download_results(self):
        """
        Retrieve computed results from the DESY cloud

        Results are looked up under ``"<event>_<rev>"``. Every returned alert is
        cached in ``self.cache_candidates`` under its ``objectId`` and the
        results are then passed to ``add_results``.
        """
        self.logger.info("Retrieving results from the DESY cloud")
        file_basename = f"{self.skymap.event}_{self.skymap.rev}"

        res = get_preprocessed_results(file_basename=file_basename)

        if res is None:
            final_objects = []
        else:
            final_objects = [alert["objectId"] for alert in res]
            for alert in res:
                self.cache_candidates[alert["objectId"]] = alert

        final_objects = self.remove_duplicates(final_objects)

        self.logger.info(
            f"Retrieved {len(final_objects)} final objects for event "
            f"{self.get_name()} from DESY cloud."
        )

        # NOTE(review): res can still be None here (no results found); this
        # relies on BaseScanner.add_results accepting None -- confirm.
        self.add_results(res, cache=self.cache_candidates)

    def in_contour(self, ra, dec):
        """
        Whether a given coordinate is within the skymap contour

        :param ra: right ascension
        :param dec: declination
        :return: bool
        """
        return self.skymap.in_contour(ra, dec)

    def unpack_skymap(self, output_nside: None | int = None):
        """
        Select the skymap pixels whose value exceeds ``skymap.pixel_threshold``.

        :param output_nside: if given, reload the skymap at this nside first
            (replacing ``self.skymap``)
        :return: tuple ``(map_coords, pixel_nos, nside, selected_values, data,
            total_pixel_area, key)``. ``map_coords`` is a structured
            ``(ra, dec)`` array of the selected pixels and ``pixel_nos`` their
            HEALPix indices. ``selected_values`` are the map values of those
            pixels and ``total_pixel_area`` their summed area in square degrees.
        """
        if output_nside is not None:
            self.skymap = Skymap(
                event=self.event,
                rev=self.rev,
                prob_threshold=self.prob_threshold,
                output_nside=output_nside,
            )

        pixel_values = self.skymap.data[self.skymap.key]
        nside = hp.npix2nside(len(pixel_values))

        mask = pixel_values > self.skymap.pixel_threshold

        # np.argmax returns the first index of the maximum, matching
        # list(...).index(max(...)) without building a Python list of every pixel.
        idx_max = int(np.argmax(pixel_values))

        ra, dec = self.extract_ra_dec(nside, idx_max)

        self.logger.info(f"hottest_pixel: {ra} {dec}")

        self.logger.info("Checking which pixels are within the contour:")

        # Visit only the pixels that pass the threshold instead of testing all
        # npix pixels one by one; .tolist() keeps plain Python ints.
        pixel_nos = np.flatnonzero(mask).tolist()
        map_coords = [self.extract_ra_dec(nside, i) for i in tqdm(pixel_nos)]

        total_pixel_area = hp.nside2pixarea(nside, degrees=True) * float(
            len(map_coords)
        )

        self.logger.info(f"Total pixel area: {total_pixel_area} sq. deg.")

        map_coords = np.array(
            map_coords, dtype=np.dtype([("ra", float), ("dec", float)])
        )

        return (
            map_coords,
            pixel_nos,
            nside,
            pixel_values[mask],
            self.skymap.data,
            total_pixel_area,
            self.skymap.key,
        )

    def find_cone_coords(self):
        """
        Collect the ``self.cone_nside`` HEALPix cells containing selected pixels.

        :return: ``(cone_ids, cone_coords)``. ``cone_ids`` are the unique pixel
            indices at ``self.cone_nside`` that contain at least one entry of
            ``self.map_coords``. ``cone_coords`` is a structured ``(ra, dec)``
            array of their coordinates, as returned by ``extract_ra_dec``.
        """

        cone_ids = []

        for ra, dec in self.map_coords:
            cone_ids.append(self.extract_npix(self.cone_nside, ra, dec))

        # De-duplicate: many fine pixels fall into the same cone cell.
        cone_ids = list(set(cone_ids))

        cone_coords = []

        for i in tqdm(cone_ids):
            cone_coords.append(self.extract_ra_dec(self.cone_nside, i))

        cone_coords = np.array(
            cone_coords, dtype=np.dtype([("ra", float), ("dec", float)])
        )

        return cone_ids, cone_coords

    def plot_skymap(self):
        """
        Plot the selected skymap pixels on an Aitoff projection.

        The figure is saved as ``skymap.png`` in the output directory and returned.

        NOTE(review): relies on ``self.data``, ``self.key``, ``self.nside`` and
        ``self.map_coords`` being set elsewhere (presumably from
        ``unpack_skymap`` by BaseScanner) -- confirm.
        """
        fig = plt.figure()
        plt.subplot(111, projection="aitoff")

        mask = self.data[self.key] > self.skymap.pixel_threshold

        size = hp.max_pixrad(self.nside, degrees=True) ** 2

        ra_map_rad = np.deg2rad(self.wrap_around_180(self.map_coords["ra"]))
        dec_map_rad = np.deg2rad(self.map_coords["dec"])

        plt.scatter(
            ra_map_rad,
            dec_map_rad,
            c=self.data[self.key][mask],
            vmin=0.0,
            vmax=np.max(self.data[self.key]),
            s=size,
        )

        plt.title("SKYMAP")

        outpath = self.get_output_dir().joinpath("skymap.png")
        plt.tight_layout()

        plt.savefig(outpath, dpi=300)

        return fig

    def plot_coverage(self, plot_candidates: bool = True, fields: list | None = None):
        """
        Plot ZTF coverage of skymap region

        :param plot_candidates: overlay every cached candidate as a white star
        :param fields: passed through to ``plot_overlap_with_observations``
        :return: ``(fig, message)`` from ``plot_overlap_with_observations``; the
            figure is also saved as ``coverage.png`` in the output directory
        """
        fig, message = self.plot_overlap_with_observations(
            first_det_window_days=self.n_days,
            fields=fields,
        )

        if plot_candidates:
            for alert in self.cache_candidates.values():
                ra = np.deg2rad(
                    self.wrap_around_180(np.array([alert["candidate"]["ra"]]))
                )
                dec = np.deg2rad(alert["candidate"]["dec"])

                plt.scatter(
                    ra, dec, color="white", marker="*", s=50.0, edgecolor="black"
                )

        plt.tight_layout()

        outpath = self.get_output_dir().joinpath("coverage.png")

        plt.savefig(outpath, dpi=300)

        self.logger.info(message)

        return fig, message