In [None]:
import numpy as np
import matplotlib.pyplot as plt
from lsst.daf.butler import Butler
from lsst.summit.utils.plotting import plot
from lsst.obs.lsst import LsstCam
import lsst.afw.cameraGeom.utils as camGeomUtils
import lsst.afw.math as afwMath
import lsst.afw.display as afwDisplay
from lsst.afw import image
from lsst.geom import Point2I
from skimage.feature import hessian_matrix, hessian_matrix_eigvals
import cv2

In [None]:
# Instrument name, camera geometry, and a Butler searching raw data,
# calibrations and the quickLook processing runs.
instrument = "LSSTCam"
camera = LsstCam.getCamera()
butler = Butler(
    'LSSTCam',
    collections=[
        "LSSTCam/raw/all",
        "LSSTCam/calib",
        "LSSTCam/runs/quickLook",
    ],
)

# First run one CCD as a test

In [None]:
dayObs = 20250827
seqNum = 654
if len(str(dayObs)) != 8:
    raise ValueError(f"dayObs must be an 8-digit YYYYMMDD integer, got {dayObs}")
if not 0 <= seqNum < 100_000:
    raise ValueError(f"seqNum must fit in 5 digits, got {seqNum}")
# Exposure ID = dayObs * 100000 + seqNum, computed in exact integer arithmetic.
expId = dayObs * 100_000 + seqNum
detName = 'R41_S01'
det = camera[detName]
detNum = det.getId()
print(detNum)
calexp = butler.get('preliminary_visit_image', detector=detNum, visit=expId, instrument=instrument)

In [None]:
%matplotlib inline
# Render the test CCD with the 'ccs' stretch, save it, and show it inline.
ccd_fig = plot(calexp, stretch='ccs')
ccd_fig.savefig(f"/home/c/cslage/u/Satellites/images/LSSTCam_{expId}_{detNum}.png")
ccd_fig

# The cell below defines `find_faint_ridges`, which does the streak-finding work.

In [None]:
def _rebin_mean(arr, factor):
    """Average a 2-D array over non-overlapping ``factor`` x ``factor`` blocks.

    Trailing rows/columns that do not fill a whole block are dropped, so this
    works for any image size (a bare ``reshape`` needs exact divisibility).
    """
    ny, nx = arr.shape[0] // factor, arr.shape[1] // factor
    cropped = arr[:ny * factor, :nx * factor]
    return cropped.reshape(ny, factor, nx, factor).mean(axis=(1, 3))


def _zero_border(mask, edge):
    """Zero an ``edge``-pixel frame on all four sides of ``mask`` in place; return it.

    ``edge <= 0`` leaves the mask untouched (a ``-0:`` slice would otherwise
    select the whole axis).
    """
    if edge > 0:
        mask[:edge, :] = 0
        mask[-edge:, :] = 0
        mask[:, :edge] = 0
        mask[:, -edge:] = 0
    return mask


def _component_points(labels, stats, label):
    """Return the (x, y) pixel coordinates of connected component ``label``.

    Only the component's bounding box (taken from ``stats``) is scanned, rather
    than the whole label image.
    """
    x0 = stats[label, cv2.CC_STAT_LEFT]
    y0 = stats[label, cv2.CC_STAT_TOP]
    w = stats[label, cv2.CC_STAT_WIDTH]
    h = stats[label, cv2.CC_STAT_HEIGHT]
    ys, xs = np.nonzero(labels[y0:y0 + h, x0:x0 + w] == label)
    # OpenCV point-set functions accept int32 (CV_32S) or float32 points.
    return np.column_stack((xs + x0, ys + y0)).astype(np.int32)


def _elongated_labels(labels, stats, num_labels, aspect):
    """Return the labels whose minimum-area rectangle has aspect ratio > ``aspect``."""
    long_labels = []
    # Label 0 is the background, not a candidate streak.
    for label in range(1, num_labels):
        points = _component_points(labels, stats, label)
        _, (width, height), _ = cv2.minAreaRect(points)
        short_side = min(width, height)
        # Degenerate components (single pixel / collinear pixels) have a
        # zero-length side; treat them as non-elongated instead of dividing by 0.
        aspect_ratio = max(width, height) / short_side if short_side > 0 else 0.0
        if aspect_ratio > aspect:
            long_labels.append(label)
    return long_labels


def _draw_streak_lines(shape, labels, stats, long_labels, streak_width):
    """Return a float32 image of ``shape`` with a line drawn through each listed component.

    Each component's pixels are fit with ``cv2.fitLine`` (``DIST_L2``). The
    fitted line is extended across the full image width and drawn with value
    255 and thickness ``streak_width``.
    """
    disp_img = np.zeros(shape, dtype=np.float32)
    last_col = shape[1] - 1
    for label in long_labels:
        points = _component_points(labels, stats, label)
        vx, vy, x0, y0 = cv2.fitLine(points, cv2.DIST_L2, 0, 0.01, 0.01).ravel()
        # Weed out nearly horizontal or vertical fits. The |vx| guard also
        # keeps the slope vy / vx bounded.
        if abs(vx) < 0.1 or abs(vy) < 0.1:
            continue
        slope = vy / vx
        # y where the fitted line crosses the first and the last column.
        left_y = int(y0 - x0 * slope)
        right_y = int(y0 + (last_col - x0) * slope)
        cv2.line(disp_img, (last_col, right_y), (0, left_y), (255, 255, 255), streak_width)
    return disp_img


def _plot_streak_stages(arr, minima_ridges, binary_ridges, disp_img, sigma, clip_max, title):
    """Show input, Hessian eigenvalue, thresholded mask and detected streaks in a 2x2 grid."""
    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    fig.subplots_adjust(hspace=0.3)
    fig.suptitle(title)
    ax = axes.ravel()
    ax[0].set_title('Original Image')
    ax[0].imshow(arr, cmap=plt.cm.gray, vmin=0, vmax=clip_max, origin='lower')
    ax[1].set_title(f'Minima Ridges (sigma={sigma})\n Negative Hessian eigenvalues')
    ax[1].imshow(minima_ridges, origin='lower', vmin=-.01, vmax=.01)
    ax[2].set_title(f'Binary Ridges (sigma={sigma})\n Thresholding minima_ridges')
    ax[2].imshow(binary_ridges, origin='lower', cmap=plt.cm.gray, vmin=0, vmax=1)
    ax[3].set_title(f'Detected Streaks (sigma={sigma})\nAfter finding longest regions')
    ax[3].imshow(disp_img, cmap=plt.cm.gray, origin='lower')
    return fig


def find_faint_ridges(calexp, sigma=1.0, bin=4, edge=20, threshold=-0.05,
                      aspect=8, streak_width=20, make_plots=False,
                      clip_max=100, title=None):
    """Find faint linear streaks in an exposure using Hessian ridge detection.

    Processing steps:

    1. Clip the pixels to ``[0, clip_max]``.
    2. Average the image over ``bin`` x ``bin`` blocks.
    3. Median-blur (3x3), then Gaussian-blur (11x11) the binned image.
    4. Compute the smaller Hessian eigenvalue and threshold it into a ridge mask.
    5. Keep connected components of that mask whose minimum-area rectangle is
       more elongated than ``aspect``.
    6. Draw a full-width fitted line through each kept component.
    7. Resize the line image back to the input size.

    Parameters
    ----------
    calexp : exposure-like
        Object exposing a 2-D pixel array as ``calexp.image.array``.
    sigma : float
        ``sigma`` passed to ``skimage.feature.hessian_matrix``.
    bin : int
        Block size (pixels) for the averaging rebin; must be >= 1.
    edge : int
        Width, in binned pixels, of the border zeroed in the ridge mask.
    threshold : float
        Binned pixels whose smaller Hessian eigenvalue is below this value
        count as ridge pixels.
    aspect : float
        Minimum rectangle aspect ratio for a component to count as a streak.
    streak_width : int
        Line thickness, in binned pixels, used when drawing detected streaks.
    make_plots : bool
        If True, draw a 2x2 diagnostic figure (left as the current figure).
    clip_max : float
        Upper clip value applied before binning; also the display ceiling
        of the "Original Image" panel.
    title : str or None
        Figure title when ``make_plots`` is True. If None, it is built from the
        exposure's visit ID and detector ID.

    Returns
    -------
    list
        One-element list ``[disp_img]``. ``disp_img`` is a float32 image with
        the same shape as ``calexp.image.array``. It is 0 away from detected
        streaks and up to 255 along them (values are linearly interpolated
        by the final resize).

    Raises
    ------
    ValueError
        If ``bin`` < 1.
    """
    if bin < 1:
        raise ValueError(f"bin must be a positive integer, got {bin}")
    arr = np.clip(calexp.image.array, a_min=0, a_max=clip_max)
    # cv2.medianBlur with ksize 3 requires uint8/uint16/float32 input.
    bin_arr = _rebin_mean(arr, bin).astype(np.float32, copy=False)

    # Smooth small features, then locate ridges with the Hessian. The smaller
    # eigenvalue ("minima ridges") has been the most effective at finding streaks.
    blurred = cv2.medianBlur(bin_arr, 3)
    gauss = cv2.GaussianBlur(blurred, (11, 11), 0)
    H_elems = hessian_matrix(gauss, sigma=sigma, order='rc', use_gaussian_derivatives=False)
    _, minima_ridges = hessian_matrix_eigvals(H_elems)

    # Binary ridge mask (setting this threshold has been difficult), with the
    # border suppressed.
    binary_ridges = _zero_border((minima_ridges < threshold).astype(np.uint8), edge)
    # Convert 0/1 -> 0/255 for OpenCV.
    _, binary = cv2.threshold(binary_ridges, 0.5, 255, cv2.THRESH_BINARY)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    long_labels = _elongated_labels(labels, stats, num_labels, aspect)
    disp_img = _draw_streak_lines(bin_arr.shape, labels, stats, long_labels, streak_width)
    # Resize back to the original size (needed for showCamera); dsize is (width, height).
    disp_img = cv2.resize(disp_img, (arr.shape[1], arr.shape[0]), interpolation=cv2.INTER_LINEAR)

    if make_plots:
        if title is None:
            visit_id = calexp.getInfo().getVisitInfo().getId()
            title = f"Streak finding {visit_id}_{calexp.getDetector().getId()}"
        _plot_streak_stages(arr, minima_ridges, binary_ridges, disp_img, sigma, clip_max, title)
    return [disp_img]


In [None]:
# Example usage: run the finder on the single test CCD with diagnostic plots,
# then save the resulting (current) figure.
disp_img, = find_faint_ridges(calexp, sigma=12.0, threshold=-0.05, make_plots=True)
plt.savefig(f"/home/c/cslage/u/Satellites/streak_images/Hessian_Streak_Finding_{dayObs}_{seqNum}_{detNum}.png")

In [None]:
# Detectors to show in the whole-camera plot. The heavily vignetted corner
# CCDs listed in `corners` are left out.
rafts = ['R01', 'R02', 'R03',
         'R10', 'R11', 'R12', 'R13', 'R14',
         'R20', 'R21', 'R22', 'R23', 'R24',
         'R30', 'R31', 'R32', 'R33', 'R34',
         'R41', 'R42', 'R43']
# S00, S01, S02, S10, ..., S22 in row-major order.
ccds = [f"S{row}{col}" for row in range(3) for col in range(3)]
corners = ['R01_S00', 'R01_S01', 'R03_S01', 'R03_S02',
           'R10_S00', 'R10_S10', 'R30_S20', 'R30_S10',
           'R41_S20', 'R41_S21', 'R43_S21', 'R43_S22',
           'R34_S22', 'R34_S12', 'R14_S02', 'R14_S12']
detectorNameList = [
    det_name
    for det_name in (f"{raft_id}_{sensor}" for raft_id in rafts for sensor in ccds)
    if det_name not in corners
]

In [None]:
def streakCallback(im, ccd, imageSource):
    """Replace a CCD's image with its streak mask for the camera mosaic.

    The incoming ``im`` and ``imageSource`` are not used. Instead, the
    ``preliminary_visit_image`` for this detector (at the global
    ``dayObs``/``seqNum``) is fetched from the Butler and passed through
    ``find_faint_ridges``. The detector ID and the mask maximum are printed;
    a nonzero maximum flags a CCD where a streak line was drawn.

    Returns an ``ImageF`` that wraps the mask array (no copy) with origin (0, 0).
    """
    det_id = ccd.getId()
    visit_image = butler.get('preliminary_visit_image', detector=det_id, day_obs=dayObs, seq_num=seqNum)
    streak_mask, = find_faint_ridges(visit_image, sigma=12.0, threshold=-0.05, streak_width=100)
    print(det_id, streak_mask.max())
    return image.ImageF(array=streak_mask, deep=False, xy0=Point2I(0, 0))

# The cell below runs the streak finder on every selected CCD (via `streakCallback`) and assembles the results into a whole-camera image.

In [None]:
%matplotlib inline
# Camera geometry from the Butler and the dataset type ButlerImage reads
# for each CCD.
instrument = "LSSTCam"
camera = butler.get('camera', instrument=instrument)
dataType = 'raw'

# Matplotlib-backed afwDisplay with a fixed 0-100 linear grey stretch.
fig = plt.figure(figsize=(12, 12))
disp = afwDisplay.Display(1, "matplotlib")
disp.scale('linear', min=0, max=100)
disp.setImageColormap("gray")

# Each CCD image is routed through streakCallback before being placed in the
# 16x-binned mosaic.
image_source = camGeomUtils.ButlerImage(
    butler,
    dataType,
    instrument=instrument,
    day_obs=dayObs,
    seq_num=seqNum,
    verbose=False,
    callback=streakCallback,
    background=np.nan,
)
mos = camGeomUtils.showCamera(
    camera,
    image_source,
    detectorNameList=detectorNameList,
    binSize=16,
    display=disp,
    overlay=False,
    title="%d %d" % (dayObs, seqNum),
)

plt.savefig(f"/home/c/cslage/u/Satellites/streak_images/Streak_Finding_{dayObs}_{seqNum}.png")