# Access PSF information from coadds

Created by: Miranda Gorsuch

This notebook provides a brief example of accessing the PSF information from both default and cell-based coadds.

LSST Science Pipelines version: Weekly 2025_15

Container Size: small (4 GB)

## Imports & Definitions

In [None]:
from lsst.daf.butler import Butler
from lsst.skymap import Index2D
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
import gc

%matplotlib inline

from lsst.skymap import Index2D
import lsst.afw.geom as afwGeom
import lsst.afw.math as afwMath
import lsst.geom as geom
from lsst.geom import Point2D
import lsst.meas.algorithms as meas

from lsst.afw.geom.ellipses import Quadrupole, SeparableDistortionTraceRadius

# Butler repository that holds the ComCam data products used in this notebook.
REPO = '/repo/main'

# Data-ID fields shared by the butler queries below.
comcam_dataId = dict(
    instrument='LSSTComCam',
    skymap='lsst_cells_v1',
)

## Default Coadds

In [None]:
collection = 'LSSTComCam/runs/DRP/DP1/w_2025_10/DM-49359'  # run holding the default (non-cell) coadds

# Butler scoped to that run, plus a convenience handle to its registry.
butler = Butler(REPO, collections=[collection])
registry = butler.registry

List the i-band coadds available in the default collection.

In [None]:
# Query every i-band deepCoadd in the default collection and print its data ID.
i_band_refs = butler.registry.queryDatasets(
    'deepCoadd',
    collections=collection,
    instrument=comcam_dataId['instrument'],
    skymap=comcam_dataId['skymap'],
    band='i',
)
for ref in i_band_refs:
    print(ref.dataId)

Load the PSF and bounding box of an example coadd (tract 10704, patch 5, i band).

In [None]:
# Data ID of the example patch. Both reads below use this one dict, so the PSF
# and the bounding box are guaranteed to come from the same coadd.
coadd_dataId = dict(comcam_dataId, tract=10704, patch=5, band='i')

# To load the entire coadd instead of individual components, use:
# coadd = butler.get('deepCoadd', collections=collection, **coadd_dataId)

# Load just the PSF component of the coadd rather than the full coadd.
coadd_psf = butler.get('deepCoadd.psf', collections=collection, **coadd_dataId)

# Load just the bounding-box component.
coadd_bbox = butler.get('deepCoadd.bbox', collections=collection, **coadd_dataId)

Find an example point by inspecting the corners of the coadd bounding box.

In [None]:
# Corners of the coadd bounding box in pixel coordinates; the example point in
# the next cell should be chosen from inside this range.
coadd_bbox.getCorners()

In [None]:
# Pick an arbitrary point inside the patch (compare with the corners above).
# Check containment explicitly so that a bad choice fails early with a clear
# message, instead of silently describing a point that is not in this coadd.
example_point = Point2D(15000, 1000)
if not geom.Box2D(coadd_bbox).contains(example_point):
    raise ValueError(f"Example point {example_point} is outside the coadd bbox {coadd_bbox}")
shape = coadd_psf.computeShape(example_point)

# Size and second moments of the PSF at that point.
trace_radius = shape.getTraceRadius()
i_xx, i_yy, i_xy = shape.getIxx(), shape.getIyy(), shape.getIxy()

# Convert the second moments to a distortion-type ellipticity (e1, e2).
q = Quadrupole(i_xx, i_yy, i_xy)
s = SeparableDistortionTraceRadius(q)
e1, e2 = s.getE1(), s.getE2()

# (e1, e2) rotates by twice the ellipse angle (it is a spin-2 quantity), so the
# position angle of the major axis, in radians, is half of arctan2(e2, e1).
theta = 0.5 * np.arctan2(e2, e1)

In [None]:
# Collect the PSF summary into one block of text and print it in one call.
summary_lines = [
    f"Trace radius: {trace_radius}",
    f"Second moments (I_xx, I_yy, and I_xy): {i_xx}, {i_yy}, {i_xy}",
    f"e1, e2: {e1}, {e2}",
    f"Theta: {theta}",
]
print("\n".join(summary_lines))

For a more thorough tutorial on PSF functions, see https://nbviewer.org/github/LSSTScienceCollaborations/StackClub/blob/rendered/Validation/image_quality_demo.nbconvert.ipynb

## Cell-Based Coadds

Other collections of cell-based coadds that can be substituted for `cell_collection` below:
- ECDFS: `'u/mgorsuch/ComCam_Cells/ECDFS/20250217T221024Z'`
- EDFS:  `'u/mgorsuch/ComCam_Cells/EDFS/20250214T210850Z'`
- Rubin SV 95 -25: `'u/mgorsuch/ComCam_Cells/Rubin_SV_95_25/20250219T024719Z'`
- Fornax: `'u/mgorsuch/ComCam_Cells/fornax/20250219T025052Z'`
- 47_tuc: `'u/mgorsuch/ComCam_Cells/47_Tuc/20250219T025226Z'`

In [None]:
cell_collection = 'u/mgorsuch/ComCam_Cells/Rubin_SV_38_7/20250214T210230Z'  # cell-based coadd run

# Butler scoped to the cell-based run, plus a convenience handle to its registry.
cell_butler = Butler(REPO, collections=[cell_collection])
cell_registry = cell_butler.registry

### For an individual cell within a patch

The `.psf` component shortcut used for the default coadds is not available for cell-based coadds, so read in the full coadd of an example patch.

In [None]:
# Read the full cell-based coadd of the example patch; each cell inside it
# carries its own PSF image.
cell_coadd = cell_butler.get(
    'deepCoaddCell',
    collections=cell_collection,
    instrument='LSSTComCam',
    skymap='lsst_cells_v1',
    tract=10704,
    patch=5,
    band='i',
)

The warning above indicates that these coadds are slightly outdated and do not include aperture correction information.

In [None]:
# Select one cell of the patch by its (x, y) index in the cell grid.
example_index = Index2D(x=2, y=15)
test_cell = cell_coadd.cells[example_index]

In [None]:
# Show which cell indices are available in this patch. Only non-empty cells
# are stored, and only the first few indices are printed to keep output short.
available_cells = list(cell_coadd.cells.keys())
print(f"{len(available_cells)} non-empty cells; first few: {available_cells[:5]}")

Retrieve the PSF information from the cell. 

In [None]:
# Each cell carries its PSF as an image rather than as a PSF model.
psf_im = test_cell.psf_image

# Wrap the PSF image in a fixed kernel so the standard KernelPsf shape
# measurement can be applied to it.
psf_kernel = afwMath.FixedKernel(psf_im)
psf = meas.KernelPsf(psf_kernel)
shape = psf.computeShape(psf_im.getBBox().getCenter())

# Size and second moments of the cell PSF.
trace_radius = shape.getTraceRadius()
i_xx, i_yy, i_xy = shape.getIxx(), shape.getIyy(), shape.getIxy()

# Convert the second moments to a distortion-type ellipticity (e1, e2).
q = Quadrupole(i_xx, i_yy, i_xy)
s = SeparableDistortionTraceRadius(q)
e1, e2 = s.getE1(), s.getE2()

# (e1, e2) rotates by twice the ellipse angle (it is a spin-2 quantity), so the
# position angle of the major axis, in radians, is half of arctan2(e2, e1).
theta = 0.5 * np.arctan2(e2, e1)

### For cells across multiple tracts

In [None]:
def get_field_info(butler, collection, data_kwargs, band='i'):
    """Retrieve the unique tract/patch combinations within a specified collection.

    Parameters
    ----------
    butler : `lsst.daf.butler.Butler`
        Butler used for the registry query.
    collection : `str`
        The relevant collection containing cell-based coadds of interest.
    data_kwargs : `dict`
        Must provide the ``'instrument'`` and ``'skymap'`` used for the query.
    band : `str`, optional
        Band of the ``deepCoaddCell`` datasets to look for (default ``'i'``).

    Returns
    -------
    field_quanta : `pandas.DataFrame`
        One row per unique (tract, patch) pair, with columns ``tract`` and
        ``patch``.
    """
    refs = butler.registry.queryDatasets('deepCoaddCell',
                                         band=band,
                                         collections=collection,
                                         instrument=data_kwargs['instrument'],
                                         skymap=data_kwargs['skymap'])
    rows = [(ref.dataId.get('tract'), ref.dataId.get('patch')) for ref in refs]
    field_quanta = pd.DataFrame(rows, columns=['tract', 'patch'])

    # If the query yields the same tract/patch more than once (e.g. from
    # several runs in a chained collection), downstream code would count and
    # process those cells twice; keep each pair once.
    return field_quanta.drop_duplicates().reset_index(drop=True)

In [None]:
def get_cell_count(field_data, butler, collection, data_kwargs, band='i'):
    """Retrieve the total number of cells from your input collection.

    Parameters
    ----------
    field_data : `pandas.DataFrame`
        Tract/patch combinations of interest, with ``tract`` and ``patch``
        columns.
    butler : `lsst.daf.butler.Butler`
        Butler used to read the cell-based coadds.
    collection : `str`
        The relevant collection containing cell-based coadds of interest.
    data_kwargs : `dict`
        Must provide the ``'instrument'`` and ``'skymap'`` used for the reads.
    band : `str`, optional
        Band of the ``deepCoaddCell`` datasets to read (default ``'i'``).

    Returns
    -------
    cell_count : `int`
        Number of cells with inputs in the specified field data.

    Notes
    -----
    This DOES include duplicate cells due to overlap of patches/tracts.
    Each full coadd is read from the butler, so this is slow for many patches.
    """
    cell_count = 0
    for field in field_data.itertuples(index=False):
        coadd = butler.get('deepCoaddCell',
                           collections=collection,
                           instrument=data_kwargs['instrument'],
                           skymap=data_kwargs['skymap'],
                           # Plain ints rather than numpy integers for the data ID.
                           tract=int(field.tract),
                           patch=int(field.patch),
                           band=band)
        # Only non-empty cells are present in `coadd.cells`.
        cell_count += len(list(coadd.cells.keys()))
        # Cell coadds are large; release each one before reading the next.
        del coadd
        gc.collect()

    return cell_count

In [None]:
def _cell_psf_measurements(cell):
    """Measure the PSF size and ellipticity of a single coadd cell.

    Parameters
    ----------
    cell
        One cell of a cell-based coadd; only its ``psf_image`` is used.

    Returns
    -------
    trace_radius, e1, e2, theta : `float`
        Trace radius of the PSF second moments, the distortion-type
        ellipticity components, and the position angle (radians) of the
        PSF major axis.
    """
    psf_im = cell.psf_image
    # Wrap the PSF image in a fixed kernel so KernelPsf.computeShape can be used.
    psf = meas.KernelPsf(afwMath.FixedKernel(psf_im))
    shape = psf.computeShape(psf_im.getBBox().getCenter())

    distortion = SeparableDistortionTraceRadius(
        Quadrupole(shape.getIxx(), shape.getIyy(), shape.getIxy()))
    e1, e2 = distortion.getE1(), distortion.getE2()
    # (e1, e2) rotates by twice the ellipse angle (spin-2), so the major-axis
    # position angle is half of arctan2(e2, e1).
    theta = 0.5 * np.arctan2(e2, e1)
    return shape.getTraceRadius(), e1, e2, theta


def get_cell_data(field_data, butler, collection, data_kwargs, band='i'):
    """Iterate through the cells of each patch and collect their PSF information.

    Parameters
    ----------
    field_data : `pandas.DataFrame`
        Tract/patch combinations of interest, with ``tract`` and ``patch``
        columns.
    butler : `lsst.daf.butler.Butler`
        Butler used to read the cell-based coadds.
    collection : `str`
        The relevant collection containing cell-based coadds of interest.
    data_kwargs : `dict`
        Must provide the ``'instrument'`` and ``'skymap'`` used for the reads.
    band : `str`, optional
        Band of the ``deepCoaddCell`` datasets to read (default ``'i'``).

    Returns
    -------
    data_df : `pandas.DataFrame`
        One row per cell with columns ``tract``, ``patch``, ``x_index``,
        ``y_index``, ``ra`` and ``dec`` (cell center, degrees),
        ``trace_radius``, ``e1``, ``e2``, ``theta`` (major-axis position
        angle, radians). It also has the derived ``e`` (ellipticity
        magnitude) and ``x_vec``/``y_vec`` (``e`` projected onto ``theta``).

    Notes
    -----
    Rows are accumulated in a list and converted to a DataFrame once. Neither
    a separate cell-counting pass nor a per-row search of a preallocated
    frame is needed.
    """
    columns = ['tract', 'patch', 'x_index', 'y_index', 'ra', 'dec',
               'trace_radius', 'e1', 'e2', 'theta']
    records = []

    for patch_number, field in enumerate(field_data.itertuples(index=False)):
        # Plain ints rather than numpy integers for the data ID.
        tract, patch = int(field.tract), int(field.patch)
        coadd = butler.get('deepCoaddCell',
                           collections=collection,
                           instrument=data_kwargs['instrument'],
                           skymap=data_kwargs['skymap'],
                           tract=tract,
                           patch=patch,
                           band=band)
        wcs = coadd.wcs

        # Only non-empty cells are present in `coadd.cells`.
        for cell_index in list(coadd.cells.keys()):
            cell = coadd.cells[cell_index]
            # Sky position of the cell center; also used below to drop
            # duplicate cells shared by overlapping patches.
            center = wcs.pixelToSky(cell.inner.bbox.getCenter())
            trace_radius, e1, e2, theta = _cell_psf_measurements(cell)
            records.append((tract, patch, cell_index.x, cell_index.y,
                            center[0].asDegrees(), center[1].asDegrees(),
                            trace_radius, e1, e2, theta))

        # Cell coadds are large; drop the reference and collect periodically.
        del coadd
        if patch_number % 5 == 0:
            gc.collect()

    print('cell num: ', len(records))

    data_df = pd.DataFrame(records, columns=columns)
    # Drop any rows containing NaN values.
    data_df = data_df.dropna()

    # Quantities derived from other columns.
    e1_col = data_df['e1'].astype(np.float64)
    e2_col = data_df['e2'].astype(np.float64)
    theta_col = data_df['theta'].astype(np.float64)
    e_col = np.sqrt(np.square(e1_col) + np.square(e2_col))
    data_df = data_df.assign(e=e_col,
                             x_vec=e_col * np.cos(theta_col),
                             y_vec=e_col * np.sin(theta_col))

    # Remove cells duplicated by patch overlap.
    # NOTE(review): exact float matching only catches duplicates whose centers
    # were computed with the same WCS (same tract). Cells repeated across
    # overlapping tracts likely differ slightly in ra/dec and survive — confirm.
    data_df = data_df.drop_duplicates(subset=['ra', 'dec'])
    gc.collect()

    return data_df

Run the functions.

The `cell_count` part takes ~8 minutes, but for this example can be set to 56260 to skip running. The rest is ~9 minutes.

(You may want to collapse the output of this cell in the notebook.)

In [None]:
# Find every tract/patch with an i-band cell coadd, then measure each cell's PSF.
field_info = get_field_info(butler=cell_butler,
                            collection=cell_collection,
                            data_kwargs=comcam_dataId)
psf_df = get_cell_data(field_data=field_info,
                       butler=cell_butler,
                       collection=cell_collection,
                       data_kwargs=comcam_dataId)