In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
from lsst.daf.butler import Butler
from lsst.summit.utils.plotting import plot
from lsst.obs.lsst import LsstCam
import lsst.afw.cameraGeom.utils as camGeomUtils
import lsst.afw.math as afwMath
import lsst.afw.display as afwDisplay
from lsst.afw import image
from lsst.geom import Point2I
from StreakFinder import find_faint_ridges
from lsst.meas.algorithms.maskStreaks import Line, LineProfile, LineCollection

In [None]:
instrument = "LSSTCam"
# Embargo repository, searched in order: raw data, unbounded calibrations,
# then the nightly-validation processing run.
butler = Butler(
    '/repo/embargo',
    collections=[
        'LSSTCam/raw/all',
        'LSSTCam/calib/unbounded',
        'LSSTCam/runs/nightlyValidation',
    ],
)
camera = LsstCam.getCamera()

# First run one CCD as a test

In [None]:
# Exposure under test (an earlier test case was dayObs = 20250909, seqNum = 313).
dayObs = 20250915
seqNum = 319
# Detector selection: set detNum to a detector number (e.g. 106) to use it
# directly; a negative value means "look it up by detName instead".
detNum = -1

# Visit ID = dayObs shifted left by five decimal digits, plus seqNum. Pure
# integer arithmetic gives the same value as int(dayObs * 1E5 + seqNum).
expId = dayObs * 100_000 + seqNum
if detNum < 0:
    detName = 'R20_S22'
    det = camera[detName]
    detNum = det.getId()
print(detNum)
calexp = butler.get('preliminary_visit_image', instrument=instrument,
                    visit=expId, detector=detNum)

In [None]:
%matplotlib inline
# Quick-look display of the single-CCD image via lsst.summit.utils.plotting.plot;
# stretch='ccs' selects the display stretch (meaning defined by that helper).
x = plot(calexp, stretch='ccs')
x

# Now test `find_faint_ridges` with its different output options

In [None]:
# output="L": the detected streaks, used below as a sequence of Line objects
# (lines[0] is passed to getLineXY and LineProfile).
lines = find_faint_ridges(calexp, output="L")
print(lines)

In [None]:
%matplotlib inline
fig = find_faint_ridges(calexp, output="P", threshold=-0.006, bin=2, sigma=12, kernel=11, aspect=8)
fig.savefig(f"/home/c/cslage/u/Satellites/streak_images/Hessian_Streak_Finding_{dayObs}_{seqNum}_{detNum}.png")

In [None]:
import astropy.units as u
def getLineXY(line, exp):
    """Return the pixel coordinates where a line crosses the image boundary.

    The line is in Hough (rho, theta) form about the image centre: in
    coordinates whose origin is the centre of ``exp``, it is the set of
    points with ``x * cos(theta) + y * sin(theta) == rho``.

    Parameters
    ----------
    line : `Line`
        Line for which to find the endpoints. Only ``line.rho`` (pixels) and
        ``line.theta`` (degrees) are read.
    exp : `lsst.afw.image.Exposure` or similar
        Image that defines the bounding box. Only ``exp.image.array.shape``
        (``(ny, nx)``) is read.

    Returns
    -------
    boxIntersections : `np.ndarray`
        ``(x, y)`` pixel coordinates of the points where the line meets the
        box edges, ordered bottom, left, top, right. Normally two rows; a line
        passing exactly through a corner can give duplicate points, and a line
        that misses the box gives an empty ``(0, 2)`` array.
    """
    ymax, xmax = exp.image.array.shape
    halfX = xmax / 2.
    halfY = ymax / 2.
    # Work in plain radians rather than an astropy Quantity, so every
    # intermediate is an ordinary float and np.array() below needs no
    # Quantity-to-float coercion (and astropy is not required).
    theta = np.deg2rad(line.theta)
    sinTheta = np.sin(theta)
    cosTheta = np.cos(theta)
    # For axis-aligned lines sin or cos is exactly zero, so some edge
    # intersections come out as +/-inf or nan. The in-box test rejects those
    # points, so the divide/invalid warnings are just noise.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Solve x*cos + y*sin = rho on each edge of the centred box.
        # Bottom:
        yA = -halfY
        xA = (line.rho - yA * sinTheta) / cosTheta
        # Left:
        xB = -halfX
        yB = (line.rho - xB * cosTheta) / sinTheta
        # Top:
        yC = halfY
        xC = (line.rho - yC * sinTheta) / cosTheta
        # Right:
        xD = halfX
        yD = (line.rho - xD * cosTheta) / sinTheta
        lineIntersections = np.array([[xA, yA],
                                      [xB, yB],
                                      [xC, yC],
                                      [xD, yD]], dtype=float)
        # Shift from centred coordinates to pixel coordinates.
        lineIntersections[:, 0] += halfX
        lineIntersections[:, 1] += halfY
        # Keep only the edge intersections that lie on the box itself.
        inBox = ((lineIntersections[:, 0] >= 0) & (lineIntersections[:, 0] <= xmax)
                 & (lineIntersections[:, 1] >= 0) & (lineIntersections[:, 1] <= ymax))
    boxIntersections = lineIntersections[inBox]

    return boxIntersections

In [None]:
# Pixel coordinates where the first detected streak crosses the CCD boundary.
boxIntersections = getLineXY(lines[0], calexp)
print(boxIntersections)

In [None]:
# Sanity check of getLineXY with a hand-made line; presumably Line(rho, theta),
# i.e. rho = 800 px and theta = 70 deg — confirm the Line argument order.
test = Line(800, 70)
boxIntersections = getLineXY(test, calexp)
print(boxIntersections)

In [None]:
# Build a LineProfile model for the first detected streak over the full CCD.
arr = calexp.image.array
# Uniform weights: every pixel is included (boolean True).
weights = np.ones_like(arr, dtype=bool)
# NOTE(review): `line` is the same object as lines[0], so setting sigma here
# also changes lines[0] for every later cell that uses it.
line = lines[0]
# Initial sigma for the profile (presumably the streak width in pixels).
line.sigma = 2.0
lineModel = LineProfile(arr, weights, line=line)#, detectionMask=detectionMask)

In [None]:
# Pixel mask built around the initial line (presumably the pixels the fit uses).
plt.imshow(lineModel.lineMask, origin='lower')

In [None]:
# Cut through the mask along row 3000, columns 2100-2499 (presumably where the
# streak crosses this CCD).
plt.plot(lineModel.lineMask[3000, 2100:2500])

In [None]:
# Fit the line profile starting from the initial line. `fitFailure` is the
# failure flag returned alongside the fitted Line.
fit, fitFailure = lineModel.fit()
# Do not let a failed fit go unnoticed: the parameters printed below would
# otherwise look just as trustworthy as a converged fit.
if fitFailure:
    print("WARNING: LineProfile.fit() reported a fit failure; "
          "the fitted rho/theta below may be unreliable.")
print(fit.rho, fit.theta)

In [None]:
# Evaluate the fitted line as a model image (indexed [row, column] below).
finalModel = lineModel.makeProfile(fit)

In [None]:
# Fitted model image, color scale clipped to 0-10.
plt.imshow(finalModel, origin='lower', vmin=0, vmax=10)

In [None]:
# Compare the fitted model with the data along a single row (y = 3000) and
# save the figure.
plt.title(f"Streak profile {dayObs} {seqNum} {detNum}")
plt.plot(finalModel[3000, :], label='Fit')
plt.plot(arr[3000, :], alpha=0.5, label='Data')

# Zoom on columns 2100-2400 (presumably where the streak crosses this row).
plt.xlim(2100, 2400)
plt.ylim(-100,500)
plt.xlabel("X (pixels)")
plt.ylabel("Flux (electrons)")
plt.legend()
# NOTE(review): hard-coded user-specific output directory; savefig fails if it does not exist.
plt.savefig(f"/home/c/cslage/u/Satellites/streak_images/Line_Profile_{dayObs}_{seqNum}_{detNum}.png")

In [None]:
# Fitted model along row 500, columns 0-999.
plt.plot(finalModel[500, 0:1000])

In [None]:
# NOTE(review): `prof` is not defined anywhere else in this notebook, so this
# and the following `prof` cells raised NameError on a fresh kernel. Build a
# dedicated LineProfile from the same inputs used for `lineModel` (arr,
# weights, lines[0]) so the mask experiments below do not modify `lineModel`.
# Assumes LineProfile provides getLineXY(line) — confirm.
prof = LineProfile(arr, weights, line=lines[0])
prof.getLineXY(lines[0])

In [None]:
# NOTE(review): assigns the attribute directly. If setLineMask (next cell)
# recomputes lineMaskSize from the new mask, this has no lasting effect —
# confirm what was intended.
prof.lineMaskSize = 20

In [None]:
# Rebuild the fit mask around lines[0]. The units of maxStreakWidth and
# nSigmaMask are not visible here (presumably pixels and sigma) — check the
# LineProfile.setLineMask signature.
prof.setLineMask(lines[0], maxStreakWidth=20, nSigmaMask=10)

In [None]:
# Display the current mask held by `prof` on a 0/1 color scale.
mask = prof.lineMask
%matplotlib inline
plt.imshow(mask, vmin=0, vmax=1, origin='lower')

In [None]:
%matplotlib inline
# Image data held by `prof`, gray scale clipped to 0-100.
plt.imshow(prof.data, cmap=plt.cm.gray, vmin=0, vmax=100, origin='lower')

In [None]:
# Private LineProfile API: model for lines[0] on the masked pixels; dModel is
# presumably the model derivatives — confirm.
model, dModel = prof._makeMaskedProfile(lines[0])

In [None]:
# Shape of the masked model (presumably one entry per masked pixel).
print(model.shape)

In [None]:
# Value range of the masked model computed in the cell above.
# NOTE(review): this previously called np.min/np.max on `prof` (a LineProfile
# object, not an array), which does not give a pixel range. `model` from
# prof._makeMaskedProfile is assumed to be what was meant — confirm.
print(np.min(model), np.max(model))

# The cells below run the streak finding on the whole focal plane.

In [None]:
# Science detectors to show, excluding the heavily vignetted corner CCDs so
# they do not dominate the whole-camera plot.
rafts = ['R01', 'R02', 'R03',
         'R10', 'R11', 'R12', 'R13', 'R14',
         'R20', 'R21', 'R22', 'R23', 'R24',
         'R30', 'R31', 'R32', 'R33', 'R34',
         'R41', 'R42', 'R43']
ccds = ['S00', 'S01', 'S02',
        'S10', 'S11', 'S12',
        'S20', 'S21', 'S22']
corners = ['R01_S00', 'R01_S01', 'R03_S01', 'R03_S02',
           'R10_S00', 'R10_S10', 'R30_S20', 'R30_S10',
           'R41_S20', 'R41_S21', 'R43_S21', 'R43_S22',
           'R34_S22', 'R34_S12', 'R14_S02', 'R14_S12']
# Every raft/sensor name in raft-major order, minus the corners.
detectorNameList = [
    detName
    for detName in (f"{raft}_{sensor}" for raft in rafts for sensor in ccds)
    if detName not in corners
]

In [None]:
def streakCallback(im, ccd, imageSource):
    """Return the streak-finder image for one detector of the camera mosaic.

    Used as the ``callback`` of `camGeomUtils.ButlerImage`. The image passed
    in (``im``) is ignored: the ``preliminary_visit_image`` for the same
    detector and exposure is fetched instead and run through
    ``find_faint_ridges(..., output="I")``.

    Parameters
    ----------
    im : image
        Image supplied by the caller; unused (required by the callback signature).
    ccd : `lsst.afw.cameraGeom.Detector`
        Detector being assembled; only ``ccd.getId()`` is used.
    imageSource : `camGeomUtils.ButlerImage`
        Caller's image source; unused.

    Returns
    -------
    oim : `lsst.afw.image.ImageF`
        Streak-finder output for this detector, with xy0 = (0, 0).

    Notes
    -----
    Reads the notebook globals ``butler``, ``instrument``, ``dayObs`` and
    ``seqNum``. Prints the detector ID and the maximum of the streak image;
    this flags which CCDs have streaks.
    """
    detId = ccd.getId()
    # Pass the instrument explicitly, as the single-CCD butler.get and the
    # ButlerImage call do, so the day_obs/seq_num data ID is unambiguous.
    calexp = butler.get('preliminary_visit_image', detector=detId,
                        day_obs=dayObs, seq_num=seqNum, instrument=instrument)
    disp_img = find_faint_ridges(calexp, output="I")
    # np.max reduces over all elements; no flatten() copy needed.
    print(detId, np.max(disp_img))
    # ImageF is the float32 image type and deep=False wraps the buffer without
    # copying, so hand it a C-contiguous float32 array. This is a no-op when
    # disp_img already is one.
    pixels = np.ascontiguousarray(disp_img, dtype=np.float32)
    oim = image.ImageF(array=pixels, deep=False, xy0=Point2I(0, 0))
    return oim

# Assemble the streak-finder output for every CCD into a whole-camera image

In [None]:
%matplotlib inline
# Build a full-focal-plane mosaic in which each CCD shows the output of
# streakCallback instead of the image ButlerImage itself loads.
instrument = "LSSTCam"
# NOTE(review): this replaces the `camera` obtained from LsstCam.getCamera()
# earlier in the notebook with the butler's camera — confirm they are meant to
# be interchangeable.
camera = butler.get('camera', instrument=instrument)
fig = plt.figure(figsize=(12,12))
disp = afwDisplay.Display(1, "matplotlib")
# Linear stretch from 0 to 100, gray colormap.
disp.scale('linear', min=0, max=100)
disp.setImageColormap("gray")
# Dataset type ButlerImage is asked for; streakCallback ignores the image it
# receives and substitutes the streak-finder output.
dataType='raw'
mos = camGeomUtils.showCamera(camera,
                              camGeomUtils.ButlerImage(butler, dataType, 
                                                       instrument=instrument, 
                                                       day_obs=dayObs, seq_num=seqNum,
                                                       verbose=False, callback=streakCallback,
                                                       background=np.nan),
                              detectorNameList=detectorNameList,
                              # binSize=16: presumably 16x16 binning of each CCD in the mosaic.
                              binSize=16, display=disp, overlay=False,
                              title="%d %d" % (dayObs, seqNum))