In [None]:
import numpy as np

from astropy.coordinates import SkyCoord
from astropy.cosmology import Planck18 as cosmo
from astropy.io import fits
from astropy.table import Table, vstack
import astropy.units as u

import lsst.afw.display as afwDisplay
from lsst.daf.butler import Butler
from lsst.geom import degrees, SpherePoint

In [None]:
# Butler repository and the run collection holding the difference-imaging outputs.
repo = "/repo/embargo_new"
collection = "u/rea3/test_0725J_v10"
# collection = "LSSTCam/raw/all"
# NOTE(review): `exposure` is not referenced anywhere else in the visible notebook.
exposure = 2025072800100

butler = Butler(repo, collections=collection)

In [None]:
# Query every difference_image dataset in the selected collection.
dataref_iterator = butler.registry.queryDatasets(
    datasetType="difference_image",
    collections=[collection],
)

In [None]:
# Materialize the query results (with expanded data IDs) into a plain list.
data_refs = list(dataref_iterator.expanded())

In [None]:
# CAUTION: mixing columns from dia_src (dia_source_unfiltered) and dia_src_table
# is dangerous, because nothing guarantees that the two catalogs share the same
# row order. The motivation was to combine the flags from dia_src with the
# science flux from dia_src_table, e.g. (NOT executed):
#
#     good = ~dia_src["slot_Shape_flag"] & \
#         (dia_src["base_PsfFlux_instFlux"] / dia_src["base_PsfFlux_instFluxErr"] > snr_threshold) & \
#         ~dia_src["base_PixelFlags_flag_edge"] & \
#         ((dia_src_table["scienceFlux"] / dia_src_table["scienceFluxErr"]) < max_science_snr) & \
#         ~dia_src_table["pixelFlags_streak"]
#
# Instead, each of the functions below filters a single catalog.

def good_src(cat, snr_threshold=7.5, max_science_snr=200):
    """Apply quality cuts to a ``dia_source_unfiltered`` catalog.

    A row is kept when it has no shape-measurement flag, its PSF-flux
    signal-to-noise exceeds ``snr_threshold``, and it is not flagged as
    touching an edge.

    Parameters
    ----------
    cat : table-like
        Catalog with ``slot_Shape_flag``, ``base_PsfFlux_instFlux``,
        ``base_PsfFlux_instFluxErr`` and ``base_PixelFlags_flag_edge`` columns.
    snr_threshold : float
        Minimum PSF-flux S/N. The cut is strict (``>``) and uses the signed
        flux, so negative-flux sources are rejected.
    max_science_snr : float
        Accepted for signature parity with ``good_src_table``, but NOT used:
        no science-flux cut is applied here.

    Returns
    -------
    A copy of the selected rows.
    """
    psf_snr = cat["base_PsfFlux_instFlux"] / cat["base_PsfFlux_instFluxErr"]
    shape_ok = ~cat["slot_Shape_flag"]
    off_edge = ~cat["base_PixelFlags_flag_edge"]

    keep = shape_ok & (psf_snr > snr_threshold) & off_edge
    return cat[keep].copy()


def good_src_table(cat, snr_threshold=7.5, max_science_snr=200):
    """Apply quality cuts to a ``dia_source_table``-style catalog.

    A row is kept when its ``snr`` exceeds ``snr_threshold``, it carries no
    shape, bad-pixel or cosmic-ray flag, and its science-image S/N
    (``scienceFlux / scienceFluxErr``) is below ``max_science_snr``. The last
    cut rejects sources with very high science-image S/N.

    Parameters
    ----------
    cat : DataFrame-like
        Catalog with ``snr``, ``shape_flag``, ``pixelFlags_bad``,
        ``pixelFlags_cr``, ``scienceFlux`` and ``scienceFluxErr`` columns.
    snr_threshold : float
        Minimum difference-image S/N (strict ``>``).
    max_science_snr : float
        Maximum science-image S/N (strict ``<``).

    Returns
    -------
    A deep copy of the selected rows. ``copy(deep=True)`` is the pandas
    ``DataFrame.copy`` signature, so ``cat`` is presumably a DataFrame; an
    astropy Table would reject the ``deep`` keyword.
    """
    science_snr = cat["scienceFlux"] / cat["scienceFluxErr"]
    unflagged = ~cat["shape_flag"] & ~cat["pixelFlags_bad"] & ~cat["pixelFlags_cr"]

    keep = (cat["snr"] > snr_threshold) & unflagged & (science_snr < max_science_snr)
    return cat[keep].copy(deep=True)

In [None]:
# Load the NED-GWF galaxy list for event S250725j into an astropy Table.
# The table is built from the raw FITS data; downstream code treats the ra/dec
# columns as plain numbers in degrees (it multiplies them by u.deg when matching).
url = "https://ned.ipac.caltech.edu/uri/NED::GWFglist/fits/S250725j/3"
ned_gwf_fits_data = fits.getdata(url)
ned_gwf = Table(ned_gwf_fits_data)

In [None]:
# Inspect the galaxy list; its ra/dec columns are used for matching below.
ned_gwf

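Match the `dia_src` catalog to this set of galaxies. (An earlier note said 5 arcmin; the matching radius actually used below is 30 arcsec, about 50 kpc at z = 0.1.)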

In [None]:
def load_and_match_dia_src_against_cat(dr, cat, dataset_type="dia_source_unfiltered", radius=30 * u.arcsec, debug=False):
    """Load the DIA sources for one data reference and match them to a sky catalog.

    The ``dataset_type`` catalog for ``dr.dataId`` is read through the
    notebook-level ``butler`` and filtered with ``good_src``. Each surviving
    source is then paired with its nearest neighbour in ``cat``, and only pairs
    closer than ``radius`` are kept.

    Parameters
    ----------
    dr : DatasetRef
        Data reference whose ``dataId`` selects the catalog to load.
    cat : astropy.table.Table
        Reference catalog with ``ra``/``dec`` columns holding bare numbers in
        degrees (they are multiplied by ``u.deg`` here).
    dataset_type : str
        Butler dataset type to load. Its ``asAstropy()`` output must provide
        the columns ``good_src`` uses plus ``coord_ra``/``coord_dec``.
    radius : astropy.units.Quantity
        Maximum angular separation for a match (strict ``<``).
    debug : bool
        Print intermediate matching results.

    Returns
    -------
    matched_dia_src, matched_cat : astropy.table.Table
        Row-aligned tables: row ``k`` of ``matched_cat`` is the ``cat`` entry
        nearest to row ``k`` of ``matched_dia_src``. Several DIA sources may
        map to the same catalog entry. Both tables are empty when nothing
        matches.
    """
    dia_src = butler.get(dataset_type, dataId=dr.dataId).asAstropy()
    good_dia_src = good_src(dia_src)

    # Nothing survived the quality cuts: return empty, schema-preserving tables
    # instead of relying on the matcher's handling of a zero-length input.
    if len(good_dia_src) == 0:
        return good_dia_src[:0].copy(), cat[:0].copy()

    # No unit is given for coord_ra/coord_dec, so these columns presumably carry
    # an angular unit already; the reference catalog columns are bare degrees.
    dia_coord = SkyCoord(good_dia_src["coord_ra"], good_dia_src["coord_dec"])
    cat_coord = SkyCoord(cat["ra"] * u.deg, cat["dec"] * u.deg)

    # Nearest catalog neighbour for every DIA source.
    idx, sep2d, _ = dia_coord.match_to_catalog_sky(cat_coord)

    close_enough = sep2d < radius
    matching_idx = idx[close_enough]

    if debug:
        print(len(dia_coord))
        # Coordinates of the matched catalog entries. (A fixed index here would
        # fail for catalogs shorter than that index.)
        print(cat_coord[matching_idx])
        print(idx)
        print(sep2d)
        print(matching_idx)
        print("----")

    matched_dia_src = good_dia_src[close_enough].copy()
    matched_cat = cat[matching_idx].copy()

    return matched_dia_src, matched_cat
        

In [None]:
# Angular-scale method of the Planck18 cosmology. It was imported as `cosmo`;
# the bare name `Planck18` is not bound in this notebook.
cosmo.arcsec_per_kpc_proper

In [None]:
# Use an approximate matching radius given redshift: convert a fixed proper
# size at a representative redshift into an angle.
z = 0.1
matching_radius_physical = 50 * u.kpc
angular_scale = cosmo.arcsec_per_kpc_proper(z)
matching_radius_angular = matching_radius_physical * angular_scale
print(matching_radius_physical, matching_radius_angular)

In [None]:
# Round the angular radius printed above (50 kpc at z = 0.1) to 30".
# This is the radius used for all matching below.
matching_radius = 30 * u.arcsec

In [None]:
# Match every difference-image data ref against the NED-GWF galaxy list.
# Each list gets one entry per data ref (possibly an empty table), so the three
# lists stay index-aligned.
dr_matches = []
dia_src_matches = []
cat_matches = []

for dr in data_refs:
    # `matching_radius` is the radius chosen above (there is no `radius` variable).
    matched_dia_src, matched_cat = load_and_match_dia_src_against_cat(dr, ned_gwf, radius=matching_radius)
    dr_matches.append(dr)
    dia_src_matches.append(matched_dia_src)
    cat_matches.append(matched_cat)



In [None]:
# Keep only the data refs that produced at least one match, and stack the
# per-data-ref match tables into two row-aligned tables.
match_counts = [len(dsm) for dsm in dia_src_matches]
dr_match = [dr for dr, n_match in zip(dr_matches, match_counts) if n_match > 0]
dia_src_match = vstack(dia_src_matches)
cat_match = vstack(cat_matches)

# Row k of `dia_src_match` (and of `cat_match`) comes from
# dr_match[dia_src_match["dr_match_index"][k]]. Row k is NOT the match for
# dr_match[k], because one data ref can contribute several rows.
nonempty_counts = np.array([n for n in match_counts if n > 0], dtype=int)
dia_src_match["dr_match_index"] = np.repeat(np.arange(len(nonempty_counts)), nonempty_counts)

In [None]:
# Data ID of the first data ref with at least one match.
dr_match[0].dataId

In [None]:
# All matched DIA sources, stacked over data refs.
dia_src_match

In [None]:
# Persist the matched DIA sources. The file name records the matching radius
# actually used (30 -> "dia_src_match_30arcsec.ecsv"). overwrite=True keeps the
# cell re-runnable: Table.write refuses to replace an existing file by default.
dia_src_match.write(f"dia_src_match_{matching_radius.to_value(u.arcsec):g}arcsec.ecsv", overwrite=True)

In [None]:
# NED-GWF galaxies matched to each row of dia_src_match (row-aligned).
cat_match

In [None]:
# Pick one matched data ref (an index into dr_match) to inspect below.
# NOTE(review): hard-coded; raises IndexError if fewer than 7 data refs matched.
i = 6
data_id = dr_match[i].dataId

In [None]:
# Load the template (`template_matched`), the science image
# (`preliminary_visit_image`) and the difference image for the selected data ID.
template = butler.get("template_matched", dataId=data_id)
science = butler.get("preliminary_visit_image", dataId=data_id)
diffim = butler.get("difference_image", dataId=data_id)

In [None]:
# Quality-filtered DIA sources for the selected data ID (same cuts as used for matching).
dia_src_catalog = butler.get("dia_source_unfiltered", dataId=data_id)
dia_src = dia_src_catalog.asAstropy()
good_dia_src = good_src(dia_src)


In [None]:
# Send afwDisplay output to the Firefly viewer.
afwDisplay.setDefaultBackend("firefly")

In [None]:
# Mask-plane transparency passed to setMaskTransparency for every frame below.
# NOTE(review): presumably 100 makes the mask overlay fully transparent; confirm.
transparency = 100

In [None]:
# Frame 1: the template. Each display cell rebinds `afw_display`, and
# plot_points_on_image draws on that global by default, so overlays land on
# the most recently created frame.
afw_display = afwDisplay.Display(frame=1)
afw_display.setMaskTransparency(transparency)
# asinh stretch from -20 to 500 (a higher upper limit than the science frame).
afw_display.scale("asinh", -20, 500)
# afw_display.scale("linear", "zscale")

afw_display.mtv(template)

In [None]:
# Frame 2: the science (preliminary visit) image.
afw_display = afwDisplay.Display(frame=2)
afw_display.setMaskTransparency(transparency)
# asinh stretch from -20 to 50.
afw_display.scale("asinh", -20, 50)
# afw_display.scale("linear", "zscale")

afw_display.mtv(science)

In [None]:
# Frame 3: the difference image. It is created last, so later overlays drawn
# through the global `afw_display` land here.
afw_display = afwDisplay.Display(frame=3)
afw_display.setMaskTransparency(transparency)
# afw_display.scale("asinh", -20, 50)
afw_display.scale("linear", "zscale")

# Display the difference image's masked image.
afw_display.mtv(diffim.maskedImage)

In [None]:
# https://dp1.lsst.io/tutorials/notebook/103/notebook-103-5.html
def plot_points_on_image(xy, size=20, ctype="orange", display=None):
    """Draw a circle marker at each (x, y) pixel position on an afw display.

    Parameters
    ----------
    xy : iterable of (x, y) pairs
        Pixel positions. Anything that unpacks into two values works
        (tuples, ``zip`` of two columns, ...).
    size : int or float
        Marker size passed to ``display.dot``.
    ctype : str
        Marker color passed to ``display.dot``.
    display : lsst.afw.display.Display, optional
        Display to draw on. Defaults to the notebook-level ``afw_display``,
        which is whichever frame was created most recently. Pass a display
        explicitly to target a specific frame.
    """
    # Resolve the global at call time, so the default follows the latest frame.
    if display is None:
        display = afw_display
    # Buffering() follows the linked tutorial; presumably it defers redraws
    # until all markers have been sent.
    with display.Buffering():
        for xi, yi in xy:
            display.dot('o', xi, yi, size=size, ctype=ctype)

In [None]:
# Mark every quality-filtered DIA source in green.
good_dia_src_xy = zip(good_dia_src["slot_Centroid_x"], good_dia_src["slot_Centroid_y"])
plot_points_on_image(good_dia_src_xy, ctype="green")

In [None]:
# Project every NED-GWF galaxy into this image's pixel frame using the diffim WCS.
wcs = diffim.getWcs()
ned_gwf_coord = [
    SpherePoint(ra_deg * degrees, dec_deg * degrees)
    for ra_deg, dec_deg in zip(ned_gwf["ra"], ned_gwf["dec"])
]
ned_gwf_xy = wcs.skyToPixel(ned_gwf_coord)

In [None]:
plot_points_on_image(ned_gwf_xy, ctype="orange")

# These blue circles are wront for some reason I don't understand yet
# plot_points_on_image(zip(dia_src_match["slot_Centroid_x"][i:i+1], dia_src_match["slot_Centroid_y"][i:i+1]), ctype="blue")