In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from dataclasses import dataclass

from lsst.daf.butler import Butler
import lsst.afw.image as afwImage
import lsst.geom as geom
from astroML.crossmatch import crossmatch_angular
from lsst.daf.butler import Butler
import numpy as np
from astropy.table import Table
from lsst.geom import Point2D
from lsst.pipe.tasks.calibrate import CalibrateTask
from lsst.pipe.tasks.characterizeImage import CharacterizeImageTask
from lsst.source.injection.inject_exposure import ExposureInjectTask
from lsst.meas.extensions.psfex.psfexPsfDeterminer import PsfexNoGoodStarsError
import h5py
import concurrent.futures
from multiprocessing import Lock, Semaphore, Manager
from astropy.io import ascii
import os
from multiprocessing import Value, Lock
import logging
import cv2

In [13]:
def psf_fit_flux_sigma(calexp, x, y):
    psf = calexp.getPsf()
    var_full = calexp.variance.array.astype(np.float64)

    p = psf.computeImage(geom.Point2D(x, y)).array.astype(np.float64)
    ph, pw = p.shape

    # center of stamp in full-image coords
    cx = int(round(x))
    cy = int(round(y))

    # stamp bounds in full image
    x0 = cx - pw//2
    y0 = cy - ph//2
    x1 = x0 + pw
    y1 = y0 + ph

    H, W = var_full.shape
    # clip to image bounds
    ix0 = max(x0, 0); iy0 = max(y0, 0)
    ix1 = min(x1, W); iy1 = min(y1, H)

    # corresponding bounds in PSF stamp coords
    px0 = ix0 - x0; py0 = iy0 - y0
    px1 = px0 + (ix1 - ix0); py1 = py0 + (iy1 - iy0)

    p_cut = p[py0:py1, px0:px1]
    v_cut = var_full[iy0:iy1, ix0:ix1]

    # normalize PSF cut to sum=1 (important if clipping happened)
    s = p_cut.sum()
    if not np.isfinite(s) or s <= 0:
        p = None
        v = None
    else:
        p = p_cut / s
        v = v_cut
    if p is None:
        return np.nan
    good = np.isfinite(v) & (v > 0) & np.isfinite(p)
    denom = np.sum((p[good]**2) / v[good])
    if denom <= 0:
        return np.nan
    return float(np.sqrt(1.0 / denom))

def mag_to_snr (mag, calexp, x, y):
    F = calexp.getPhotoCalib().magnitudeToInstFlux(mag)
    sigmaF = psf_fit_flux_sigma(calexp, x, y)
    snr = F / sigmaF
    return snr

def snr_to_mag(snr, calexp, x, y):
    sigmaF = psf_fit_flux_sigma(calexp, x, y)
    F = snr * sigmaF
    mag = calexp.getPhotoCalib().instFluxToMagnitude(F)
    return mag
    

def draw_one_line(mask, origin, angle, length, true_value=1, line_thickness=500):
    x0, y0 = origin
    x_size = length * np.cos((np.pi / 180) * angle)
    y_size = length * np.sin((np.pi / 180) * angle)
    x1 = x0 - x_size / 2
    y1 = y0 - y_size / 2
    x0 = x0 + x_size / 2
    y0 = y0 + y_size / 2
    one_line_mask = cv2.line(np.zeros(mask.shape), (int(x0), int(y0)), (int(x1), int(y1)), 1, thickness=line_thickness)
    mask[one_line_mask != 0] = true_value
    return mask

def characterizeCalibrate(postISRCCD):
    char_config = CharacterizeImageTask.ConfigClass()
    char_config.doApCorr = True
    char_config.doDeblend = True
    char_task = CharacterizeImageTask(config=char_config)
    char_result = char_task.run(postISRCCD)
    
    calib_config = CalibrateTask.ConfigClass()
    calib_config.doAstrometry = False
    calib_config.doPhotoCal = False
    calib_task = CalibrateTask(config=calib_config, icSourceSchema=char_result.sourceCat.schema)
    calib_result = calib_task.run(postISRCCD, background=char_result.background, icSourceCat=char_result.sourceCat)
    return calib_result.outputExposure, calib_result.sourceCat


def inject(postISRCCD, injection_catalog):
    inject_task = ExposureInjectTask()
    inject_res = inject_task.run([injection_catalog], postISRCCD, postISRCCD.psf, postISRCCD.photoCalib, postISRCCD.wcs)
    return inject_res.output_exposure


def generate_one_line(n_inject, trail_length, mag, beta, butler, ref, dimensions, source_type, seed, calexp=None):
    rng = np.random.default_rng(seed)
    injection_catalog = Table(
        names=('injection_id', 'ra', 'dec', 'source_type', 'trail_length', 'mag', 'beta', 'visit', 'detector',
               'integrated_mag', 'PSF_mag', 'SNR', 'physical_filter', 'x', 'y'),
        dtype=('int64', 'float64', 'float64', 'str', 'float64', 'float64', 'float64', 'int64', 'int64', 'float64',
               'float64', 'float64', 'str', 'int64', 'int64'))
    raw = butler.get(source_type + ".wcs", dataId=ref.dataId)
    info = butler.get(source_type + ".visitInfo", dataId=ref.dataId)
    filter_name = butler.get(source_type + ".filter", dataId=ref.dataId)
    fwhm = {"u": 0.92, "g": 0.87, "r": 0.83, "i": 0.80, "z": 0.78, "y": 0.76}
    m5 = {"u": 23.7, "g": 24.97, "r": 24.52, "i": 24.13, "z": 23.56, "y": 22.55}
    psf_depth = m5[filter_name.bandLabel]
    pixelScale = raw.getPixelScale().asArcseconds()
    theta_p = fwhm[filter_name.bandLabel] * pixelScale
    a, b = 0.67, 1.16
    for k in range(n_inject):
        x_pos = rng.uniform(1, dimensions.x - 1)
        y_pos = rng.uniform(1, dimensions.y - 1)
        sky_pos = raw.pixelToSky(x_pos, y_pos)
        ra_pos = sky_pos.getRa().asDegrees()
        dec_pos = sky_pos.getDec().asDegrees()
        inject_length = rng.uniform(*trail_length)
        x = inject_length / (24 * theta_p)
        upper_limit_mag = psf_depth - 1.25 * np.log10(1 + (a * x ** 2) / (1 + b * x)) if mag[1] == 0 else mag[1]
        magnitude = rng.uniform(mag[0], upper_limit_mag)
        surface_brightness = magnitude + 2.5 * np.log10(inject_length)
        psf_magnitude = magnitude + 1.25 * np.log10(1 + (a * x ** 2) / (1 + b * x))
        angle = rng.uniform(*beta)
        if calexp is not None:
            snr = mag_to_snr(psf_magnitude, calexp, x_pos, y_pos)
            injection_catalog.add_row([k, ra_pos, dec_pos, "Trail", inject_length, surface_brightness, angle, info.id,
                                       int(ref.dataId["detector"]), magnitude, psf_magnitude, snr, str(filter_name.bandLabel),
                                       x_pos, y_pos])
        else:
            injection_catalog.add_row([k, ra_pos, dec_pos, "Trail", inject_length, surface_brightness, angle, info.id,
                                       int(ref.dataId["detector"]), magnitude, psf_magnitude, str(filter_name.bandLabel),
                                       x_pos, y_pos])
    return injection_catalog

def stack_hits_by_footprints(
    post_src,
    calexp_pre,
    dimensions,
    truth_id_mask,
    injection_catalog,
    overlap_frac=0.02,
    overlap_minpix=100,
):
    H, W = int(dimensions.y), int(dimensions.x)
    N = len(injection_catalog)

    det_flag = np.zeros(N, bool)
    det_mag = np.full(N, np.nan)
    det_magerr = np.full(N, np.nan)
    det_snr = np.full(N, np.nan)
    matched_fp_masks = [np.zeros((H, W), bool) for _ in range(N)]

    mags = calexp_pre.photoCalib.instFluxToMagnitude(post_src, "base_PsfFlux")

    ys, xs = np.nonzero(truth_id_mask)
    ids = truth_id_mask[ys, xs] - 1

    pix_by_id = [[] for _ in range(N)]
    for y, x, i in zip(ys, xs, ids):
        if 0 <= i < N:
            pix_by_id[i].append((y, x))

    for inj_id in range(N):
        if not pix_by_id[inj_id]:
            continue

        pts = pix_by_id[inj_id]
        yy = np.array([p[0] for p in pts])
        xx = np.array([p[1] for p in pts])

        y0, y1 = yy.min(), yy.max()
        x0, x1 = xx.min(), xx.max()
        truth_count = len(pts)

        th = np.zeros((y1 - y0 + 1, x1 - x0 + 1), bool)
        for y, x in pts:
            th[y - y0, x - x0] = True

        best_ov, best_idx, best_fp = 0, None, None

        for idx in range(len(post_src)):
            fp = post_src[idx].getFootprint()
            n_pix_footprint = fp.getArea()
            required = max(overlap_minpix, int(overlap_frac * n_pix_footprint))
            bb = fp.getBBox()
            if bb.getEndX() < x0 or bb.getBeginX() > x1 or bb.getEndY() < y0 or bb.getBeginY() > y1:
                continue

            fm = np.zeros_like(th)
            ov = 0
            for span in fp.spans:
                y = span.getY()
                if y < y0 or y > y1:
                    continue
                sx0 = max(span.getX0(), x0)
                sx1 = min(span.getX1(), x1)
                if sx0 <= sx1:
                    fm[y - y0, sx0 - x0 : sx1 - x0 + 1] = True
                    ov = int((fm & th).sum())
                    if ov >= required:
                        break

            if ov > best_ov:
                best_ov, best_idx, best_fp = ov, idx, fm
                if ov >= required:
                    break

        if best_idx is not None and best_ov >= required:
            det_flag[inj_id] = True
            det_mag[inj_id] = mags[best_idx, 0]
            det_magerr[inj_id] = mags[best_idx, 1]
            f = float(post_src[best_idx].get("base_PsfFlux_instFlux"))
            fe = float(post_src[best_idx].get("base_PsfFlux_instFluxErr"))
            det_snr[inj_id] = f / fe if (np.isfinite(f) and np.isfinite(fe) and fe > 0) else np.nan
            matched_fp_masks[inj_id][y0:y1+1, x0:x1+1] |= best_fp

    injection_catalog["stack_detection"] = det_flag
    injection_catalog["stack_mag"] = det_mag
    injection_catalog["stack_mag_err"] = det_magerr
    injection_catalog["stack_snr"] = det_snr
    return injection_catalog, matched_fp_masks

def crossmatch_catalogs (pre, post):
    # Crossmatch POST vs PRE on-sky (inputs in radians, radius in radians)
    #    post-only detections -> "new" sources likely caused by injections
    if len(pre) > 0 and len(post) > 0:
        # arrays of [ra, dec] in radians
        P = post.asAstropy().to_pandas()[["coord_ra", "coord_dec"]].values
        R = pre.asAstropy().to_pandas()[["coord_ra", "coord_dec"]].values
        max_sep = np.deg2rad(0.40 / 3600.0)  # 0.40 arcsec -> radians
        dist, ind = crossmatch_angular(P, R, max_sep)
        is_new = np.isinf(dist)  # no match in PRE â†’ likely injected
        new_post = post[is_new].copy()
    else:
        new_post = post.copy()
    return new_post

def one_detector_injection(n_inject, trail_length, mag, beta, repo, coll, dimensions, source_type, ref_dataId, seed=123, debug=False):
    try:
        if seed is None:
            seed = np.random.randint(0,10000)
        butler = Butler(repo, collections=coll)
        ref = butler.registry.findDataset(source_type, dataId=ref_dataId)
        calexp = butler.get("preliminary_visit_image", dataId=ref.dataId)
        injection_catalog = generate_one_line(n_inject, trail_length, mag, beta, butler, ref, dimensions, source_type, seed, calexp)
        pre_injection_Src = butler.get("single_visit_star_footprints", dataId=ref.dataId)
        injected_calexp, post_injection_Src = characterizeCalibrate(inject(calexp, injection_catalog))
        mask = np.zeros((dimensions.y, dimensions.x), dtype=int)
        for i, row in enumerate(injection_catalog):
            psf_width = injected_calexp.psf.getLocalKernel(Point2D(row["x"], row["y"])).getWidth()
            mask = draw_one_line(mask, [row["x"], row["y"]], row["beta"], row["trail_length"], true_value=i + 1,
                                 line_thickness=int(psf_width/2))
        injection_catalog, matched_fp_mask = stack_hits_by_footprints(post_src=crossmatch_catalogs (pre_injection_Src, post_injection_Src),
                                                                       calexp_pre=calexp,
                                                                       dimensions=dimensions,
                                                                       truth_id_mask=mask,
                                                                       injection_catalog=injection_catalog,
                                                                       overlap_minpix=10,
                                                                       overlap_frac=0.5,)

        if not debug:
            return injected_calexp.image.array.astype("float32"), mask.astype("bool"), injection_catalog
        else:
            det_mask = None
            m = injected_calexp.mask
            if "DETECTED" in m.getMaskPlaneDict():
                det_bit = m.getPlaneBitMask("DETECTED")
                det_mask = (m.array & det_bit) != 0
            det_neg_mask = None
            if "DETECTED_NEGATIVE" in m.getMaskPlaneDict():
                detn_bit = m.getPlaneBitMask("DETECTED_NEGATIVE")
                det_neg_mask = (m.array & detn_bit) != 0
            matched_fp_masks = np.any(np.stack(matched_fp_mask, axis=-1), axis=-1)
            return injected_calexp.image.array.astype("float32"), mask.astype("bool"), injection_catalog, det_mask, matched_fp_masks
    except Exception as e:
        print(f"[WARNING] Skipping {ref_dataId} due to failure: {e}")
        return None

repo = "/repo/main"
collection = "LSSTComCam/runs/DRP/DP1/w_2025_17/DM-50530"
#where = "instrument='LSSTComCam' AND skymap='lsst_cells_v1' AND day_obs>=20241101 AND day_obs<=20241127 AND exposure.observation_type='science' AND band in ('u','g','r','i','z','y') AND (exposure not in (2024110600163, 2024110800318, 2024111200185, 2024111400039, 2024111500225, 2024111500226, 2024111500239, 2024111500240, 2024111500242, 2024111500288, 2024111500289, 2024111800077, 2024111800078, 2024112300230, 2024112400094, 2024112400225, 2024112600327))"
where = "instrument='LSSTComCam' AND skymap='lsst_cells_v1' AND day_obs>=20241101 AND day_obs<=20241127 AND exposure.observation_type='science' AND band in ('u','g','r','i','z','y') AND (exposure=2024112600184)"
butler = Butler(repo, collections=collection)
refs = list(set(butler.registry.queryDatasets("preliminary_visit_image", where=where, instrument="LSSTComCam", findFirst=True)))
refs = sorted(refs, key=lambda r: str(r.dataId["visit"]*1000+r.dataId["detector"]))

In [15]:
refs[4]

DatasetRef(DatasetType('preliminary_visit_image', {band, instrument, day_obs, detector, physical_filter, visit}, ExposureF), {instrument: 'LSSTComCam', detector: 4, visit: 2024112600184, band: 'r', day_obs: 20241126, physical_filter: 'r_03'}, run='LSSTComCam/runs/DRP/DP1/w_2025_17/DM-50530/20250425T225941Z', id=cd6db64a-ef04-459f-be83-1b7fb37f009a)

In [20]:
repo = "/repo/main"
collection = "LSSTComCam/runs/DRP/DP1/w_2025_17/DM-50530"
where = "instrument='LSSTComCam' AND skymap='lsst_cells_v1' AND day_obs>=20241101 AND day_obs<=20241127 AND exposure.observation_type='science' AND band in ('u','g','r','i','z','y') AND (exposure not in (2024110600163, 2024110800318, 2024111200185, 2024111400039, 2024111500225, 2024111500226, 2024111500239, 2024111500240, 2024111500242, 2024111500288, 2024111500289, 2024111800077, 2024111800078, 2024112300230, 2024112400094, 2024112400225, 2024112600327))"
where = "instrument='LSSTComCam' AND skymap='lsst_cells_v1' AND day_obs>=20241101 AND day_obs<=20241127 AND exposure.observation_type='science' AND band in ('u','g','r','i','z','y') AND (exposure=2024112400111)"

butler = Butler(repo, collections=collection)
refs = list(set(butler.registry.queryDatasets("preliminary_visit_image", where=where, instrument="LSSTComCam", findFirst=True)))
refs = sorted(refs, key=lambda r: str(r.dataId["visit"]*1000+r.dataId["detector"]))
dimensions = butler.get("preliminary_visit_image", dataId=refs[0].dataId).getDimensions()
injected_calexp, mask, injection_catalog = one_detector_injection(20, [1, 1], [23.8,23.9], [0,180], repo, collection, dimensions, "preliminary_visit_image", refs[4].dataId)



TypeError: cannot unpack non-iterable NoneType object

In [23]:
calexp = butler.get("preliminary_visit_image", dataId=refs[4].dataId)

In [26]:
calexp.wcs.getPixelScale()

AttributeError: 'NoneType' object has no attribute 'getPixelScale'