In [None]:
### Imports

import os

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.stats import sigma_clip
from scipy.ndimage import label, maximum_position, shift, center_of_mass
from scipy.ndimage import maximum_filter

## Helper Functions

In [None]:
### Function to detect sources; returns list of positions of maxima of potential sources

def detect_sources(image, threshold):
    """Find the brightest pixel of each connected region above ``threshold``.

    Parameters
    ----------
    image : ndarray
        Image to search.
    threshold : scalar
        Pixels strictly greater than this value are treated as source pixels.

    Returns
    -------
    list
        One position per connected above-threshold region, as produced by
        ``scipy.ndimage.maximum_position``. An empty list is returned when
        no pixel exceeds the threshold.
    """
    regions, n_regions = label(image > threshold)

    # Nothing above the threshold: there are no labels to query.
    if n_regions < 1:
        return []

    region_ids = np.arange(1, n_regions + 1)
    return maximum_position(image, labels=regions, index=region_ids)

In [None]:
### Function to filter out sources that have other close sources within some number of pixels, as set by the user in variable 'min_sep'
### Eliminates binary star pairs and other sources contaminated by nearby sources

def remove_close_pairs(src_positions, min_sep=12):
    """Keep only the positions that have no other position within ``min_sep``.

    Every member of a close group is dropped, not just the fainter one.

    Parameters
    ----------
    src_positions : sequence of 2-tuples
        Source coordinates; the sequence is iterated once per source.
    min_sep : scalar, optional
        Two sources are "close" when their Euclidean distance is strictly
        smaller than this value.

    Returns
    -------
    list of tuple
        The isolated positions, in their original order.
    """
    def has_close_neighbour(idx, x1, y1):
        # Compare against every *other* entry; stops at the first hit.
        return any(
            np.sqrt((x1 - x2)**2 + (y1 - y2)**2) < min_sep
            for other, (x2, y2) in enumerate(src_positions)
            if other != idx
        )

    return [
        (x1, y1)
        for idx, (x1, y1) in enumerate(src_positions)
        if not has_close_neighbour(idx, x1, y1)
    ]

In [None]:
### Function to measure radial brightness profile of source cutouts; only considers data within bounds of 'max_radius' as set by user

def radial_profile(stamp, center_y, center_x, max_radius=15):
    """Average ``stamp`` in one-pixel-wide rings around a centre.

    Parameters
    ----------
    stamp : 2-D ndarray
        Cutout to profile.
    center_y, center_x : scalar
        Centre of the rings, in (row, column) pixel coordinates.
    max_radius : int, optional
        Largest integer radius to report.

    Returns
    -------
    radii : ndarray
        ``np.arange(max_radius + 1)``.
    profile : ndarray
        Mean pixel value of each ring; 0.0 for rings that contain no pixel.
    """
    n_rows, n_cols = stamp.shape
    row_idx, col_idx = np.indices((n_rows, n_cols))

    # Truncating the distance to an integer assigns each pixel to a ring:
    # ring k holds the pixels with k <= distance < k + 1.
    ring = np.sqrt((col_idx - center_x)**2 + (row_idx - center_y)**2).astype(int)

    def ring_mean(rad):
        in_ring = (ring == rad)
        return np.mean(stamp[in_ring]) if np.any(in_ring) else 0.0

    radii = np.arange(max_radius + 1)
    return radii, np.array([ring_mean(rad) for rad in radii])

In [None]:
### Function to verify that sources generally slope downward from center as expected, with no profile increase greater than the 'max_up_fraction' as set by user

def is_monotonically_decreasing(profile, max_up_fraction=1.05):
    """Check that a profile never rises by more than the allowed factor.

    Parameters
    ----------
    profile : sequence of scalars
        Values ordered from the centre outwards.
    max_up_fraction : float, optional
        A step is rejected when the current value is strictly greater than
        ``max_up_fraction`` times the previous value.

    Returns
    -------
    bool
        ``False`` as soon as one step exceeds the allowance, ``True``
        otherwise (including for profiles with fewer than two entries).
    """
    steps = zip(profile[:-1], profile[1:])
    return not any(
        current > max_up_fraction * previous
        for previous, current in steps
    )

In [None]:
### Function that returns the minimum positive pixel value of a source/PSF cutout, used as a background estimate (returns 0 if no pixel is positive)

def measure_min_nonzero_background_entire(cutout):
    """Return the smallest strictly positive pixel value in ``cutout``.

    Parameters
    ----------
    cutout : ndarray
        Source or PSF cutout.

    Returns
    -------
    scalar
        Minimum of the pixels greater than zero, or 0 when there are none.
    """
    positive = cutout[cutout > 0]
    return np.min(positive) if positive.size else 0

In [None]:
### Function to mask out local companions surrounding the main star in each source; uses a threshold of 6 times the local background (the median of the cutout corners)

def mask_local_maxima_around_companion(stamp,
                                       main_star=None,
                                       background_corner_size=10,
                                       mask_radius=3,
                                       local_max_size=3):
    """Replace a disc around every secondary local maximum with NaN.

    A pixel counts as a companion peak when it equals the maximum of its
    ``local_max_size`` neighbourhood and is at least 6 times the corner
    background of the stamp (0.1 when that product is exactly zero). Every
    such peak other than the main star is masked.

    Parameters
    ----------
    stamp : 2-D ndarray
        Cutout to process; it is not modified.
    main_star : tuple (row, col), optional
        Position of the star to keep. Defaults to the brightest pixel.
    background_corner_size : int, optional
        Corner size passed to ``measure_cutout_corners_background``.
    mask_radius : scalar, optional
        Radius (pixels) of the disc set to NaN around each companion peak.
    local_max_size : int, optional
        Footprint size of the maximum filter used to find local maxima.

    Returns
    -------
    ndarray
        Copy of ``stamp`` with companion regions set to NaN. Integer or
        boolean input is promoted to float so that NaN can be stored.
    """
    stamp_masked = stamp.copy()

    # NaN cannot be written into integer/boolean arrays (NumPy raises a
    # ValueError), so promote those dtypes; float input keeps its dtype.
    if not np.issubdtype(stamp_masked.dtype, np.inexact):
        stamp_masked = stamp_masked.astype(float)

    if main_star is None:
        main_star = maximum_position(stamp_masked)
    ypeak, xpeak = main_star

    avg_bg = measure_cutout_corners_background(stamp_masked, corner_size=background_corner_size)
    companion_threshold = 6.0 * avg_bg
    if companion_threshold == 0:
        # A zero background would flag every flat zero pixel as a peak.
        companion_threshold = 0.1

    # All peaks are located before any pixel is masked, so masking one
    # companion cannot hide or create another.
    local_max = maximum_filter(stamp_masked, size=local_max_size)
    peaks = (stamp_masked == local_max) & (stamp_masked >= companion_threshold)

    h, w = stamp_masked.shape
    yy, xx = np.indices((h, w))

    for (py, px) in np.argwhere(peaks):
        if (py, px) != (ypeak, xpeak):
            rr = np.sqrt((xx - px)**2 + (yy - py)**2)
            stamp_masked[rr <= mask_radius] = np.nan

    return stamp_masked

In [None]:
### Function to calculate a median local background for a source/PSF cutout, using the data from the corners of the cutout (corner size set by user)

def measure_cutout_corners_background(cutout, corner_size=10):
    """Median of the pixels in the four corner squares of ``cutout``.

    Parameters
    ----------
    cutout : 2-D ndarray
        Source or PSF cutout.
    corner_size : int, optional
        Side length (pixels) of each corner square.

    Returns
    -------
    scalar
        ``np.median`` of all pixels drawn from the four corners.
    """
    h, w = cutout.shape

    first_rows, last_rows = slice(None, corner_size), slice(h - corner_size, None)
    first_cols, last_cols = slice(None, corner_size), slice(w - corner_size, None)

    # Order: top-left, top-right, bottom-left, bottom-right.
    corner_pixels = np.concatenate([
        cutout[rows, cols].ravel()
        for rows in (first_rows, last_rows)
        for cols in (first_cols, last_cols)
    ])
    return np.median(corner_pixels)

In [None]:
### Function to align the PSF with the position of the main star in the source cutout

def shift_psf_to_star(psf, star_x, star_y, cutout_size):
    """Translate ``psf`` so that its centre lands on a star position.

    Parameters
    ----------
    psf : 2-D ndarray
        PSF image whose reference point is ``cutout_size // 2`` on both axes.
    star_x : scalar
        Target coordinate along array axis 0 (the first index).
    star_y : scalar
        Target coordinate along array axis 1 (the second index).
    cutout_size : int
        Side length used to locate the PSF centre.

    Returns
    -------
    ndarray
        The shifted PSF (linear interpolation); pixels shifted in from
        outside the array are filled with 0.0.
    """
    centre = cutout_size // 2

    # Offsets are given in array-axis order: despite its name, star_x is the
    # axis-0 coordinate and star_y the axis-1 coordinate.
    offsets = (star_x - centre, star_y - centre)
    return shift(psf, offsets, order=1, mode='constant', cval=0.0)

In [None]:
### Function to fit the PSF model to each source cutout using a least-squares computational approach
### The design matrix includes a constant column (all ones), so a uniform background level is fitted together with the star amplitudes

def linear_amplitude_fit(cutout, star_positions, psf, delta=5.0, max_iter=10, tol=1e-6):
    """Fit PSF amplitudes plus a constant background to a cutout.

    The model is ``sum_i amp_i * shifted_psf_i + background``. It is solved
    by iteratively reweighted least squares with Huber weights: a pixel gets
    weight 1 when its absolute residual is at most ``delta`` and
    ``delta / |residual|`` otherwise.

    Parameters
    ----------
    cutout : 2-D ndarray
        Image to fit. Its first dimension is passed to ``shift_psf_to_star``
        as the cutout size, and ``psf`` must have the same number of pixels.
    star_positions : sequence of 2-tuples
        Positions forwarded to ``shift_psf_to_star`` as ``(star_x, star_y)``.
    psf : 2-D ndarray
        PSF model to shift onto each position.
    delta : float, optional
        Residual size above which pixels are down-weighted.
    max_iter : int, optional
        Maximum number of reweighting iterations (at least 1).
    tol : float, optional
        Iteration stops once the Euclidean norm of the change in the
        solution falls below this value.

    Returns
    -------
    x : ndarray
        Fitted star amplitudes followed by the background level.
    model_img : ndarray
        Model image with the shape of ``cutout``.

    Raises
    ------
    ValueError
        If ``max_iter`` is smaller than 1 (no solution would be computed).
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    h, w = cutout.shape
    npix = h * w
    nstar = len(star_positions)
    data_flat = cutout.ravel()

    # Design matrix: one column per shifted PSF, last column for the background.
    A = np.zeros((npix, nstar + 1), dtype=np.float64)
    shifted_psfs = []
    for i, (sx, sy) in enumerate(star_positions):
        shifted_p = shift_psf_to_star(psf, sx, sy, h)
        shifted_psfs.append(shifted_p)
        A[:, i] = shifted_p.ravel()
    A[:, -1] = 1.0

    weights = np.ones(npix)
    x_old = None

    for _ in range(max_iter):
        W_sqrt = np.sqrt(weights)
        x = np.linalg.lstsq(A * W_sqrt[:, np.newaxis], data_flat * W_sqrt, rcond=None)[0]

        if x_old is not None and np.linalg.norm(x - x_old) < tol:
            break
        x_old = x

        # Huber weights. The division is restricted to the pixels that need
        # it, so an exactly-zero residual no longer triggers a
        # divide-by-zero warning. `~(<=)` rather than `>` keeps NaN
        # residuals on the division branch.
        abs_r = np.abs(data_flat - A @ x)
        weights = np.ones_like(abs_r)
        down_weighted = ~(abs_r <= delta)
        weights[down_weighted] = delta / abs_r[down_weighted]

    model_img = np.zeros_like(cutout, dtype=np.float64)
    for amp, shifted_p in zip(x[:nstar], shifted_psfs):
        model_img += amp * shifted_p
    model_img += x[-1]

    return x, model_img

In [None]:
### Function to mask out the cores of companion stars within a source cutout
### Mask radius for companions is set by the user in 'mask_radius'

def companion_mask(cutout_shape, star_positions, main_star_index=0, mask_radius=4):
    """Boolean mask that is False in a square around every companion star.

    Parameters
    ----------
    cutout_shape : tuple
        Shape of the mask to create.
    star_positions : sequence of 2-tuples
        Star coordinates in array-axis order (axis 0 first).
    main_star_index : int, optional
        Index in ``star_positions`` of the star that must not be masked.
    mask_radius : int, optional
        Half-width of the square masked around each companion; the square
        is clipped at index 0 and at the array size.

    Returns
    -------
    ndarray of bool
        True where pixels are kept, False inside companion squares.
    """
    keep = np.ones(cutout_shape, dtype=bool)

    for idx, (sx, sy) in enumerate(star_positions):
        if idx != main_star_index:
            c0, c1 = int(sx), int(sy)
            lo0 = max(0, c0 - mask_radius)
            hi0 = min(cutout_shape[0], c0 + mask_radius + 1)
            lo1 = max(0, c1 - mask_radius)
            hi1 = min(cutout_shape[1], c1 + mask_radius + 1)
            keep[lo0:hi0, lo1:hi1] = False

    return keep

## Main PSF Refinement Function

In [None]:
def refine_psf_from_mosaic(mosaic_path,
                           output_dir,
                           cutout_size=105,
                           num_stars=200,
                           threshold_sigma=4.0,
                           min_sep=5,
                           sat_limit=80,
                           iterations=2,
                           alpha=0.001,
                           sigma_clip_level=10.0,
                           mask_radius=2,
                           save_residuals=True):
    """Build an empirical PSF from stars in a mosaic and refine it iteratively.

    Pipeline:

    1. Detect sources above ``median + threshold_sigma * std`` of the image
       and sort them by peak brightness.
    2. Reject sources that have a neighbour closer than ``min_sep`` pixels,
       whose cutout would leave the image, that contain a zero-valued pixel
       within 10 pixels of the centre, that are not the brightest pixel of
       their cutout, or whose re-centred radial profile rises by more than
       5 % between adjacent radii.
    3. Keep the first ``num_stars`` survivors, drop those whose cutout
       reaches ``sat_limit`` or has non-positive flux, and re-centre each
       cutout on its centre of mass.
    4. Median-stack the background-subtracted, companion-masked cutouts to
       obtain the initial PSF (normalised to unit sum).
    5. For each iteration, fit the current PSF to every cutout, stack the
       positive residuals (sigma-clipped median), subtract ``alpha`` times
       that stack from the PSF, clip at zero and renormalise.

    Files written to ``output_dir``: ``initial_cutout_<i>.fits``,
    ``stampmask<i>.fits``, ``psf_initial.fits``, ``rmstack_iter<n>.fits``,
    ``psf_iter<n>.fits`` and ``psf_final.fits``.

    Parameters
    ----------
    mosaic_path : str
        FITS file whose primary HDU holds the 2-D mosaic.
    output_dir : str
        Directory for all output files; created if missing.
    cutout_size : int, optional
        Nominal cutout size; stamps are ``2 * (cutout_size // 2) + 1``
        pixels on a side.
    num_stars : int, optional
        Number of brightest accepted sources considered for the PSF.
    threshold_sigma : float, optional
        Detection threshold in units of the image standard deviation.
    min_sep : scalar, optional
        Minimum separation (pixels) between accepted sources.
    sat_limit : scalar, optional
        Cutouts whose maximum reaches this value are discarded.
    iterations : int, optional
        Number of refinement iterations.
    alpha : float, optional
        Fraction of the stacked residual subtracted from the PSF per iteration.
    sigma_clip_level : float, optional
        Sigma threshold used when clipping the residual stack.
    mask_radius : int, optional
        Half-width of the square masked around companion cores.
    save_residuals : bool, optional
        Accepted for interface compatibility; it currently has no effect
        (per-cutout residuals are not written).

    Returns
    -------
    ndarray or None
        The final PSF *before* trimming, or ``None`` when no usable cutout
        was found. The array saved as ``psf_final.fits`` has 4 pixels
        trimmed from every edge.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load Data
    with fits.open(mosaic_path) as hdul:
        data = hdul[0].data

    # Detect and Filter Sources
    med_val = np.median(data)
    std_val = np.std(data)
    detect_thresh = med_val + threshold_sigma * std_val

    all_positions = detect_sources(data, detect_thresh)
    print(f"Detected {len(all_positions)} sources above threshold={detect_thresh:.2f}.")

    # Sort by Brightness
    sorted_positions = sorted(all_positions,
                              key=lambda pos: data[pos[0], pos[1]],
                              reverse=True)

    # Remove Close Pairs (Binary Pairs)
    isolated = remove_close_pairs(sorted_positions, min_sep=min_sep)

    half = cutout_size // 2
    stamp_side = 2 * half + 1

    # Every stamp has the same shape, so the 10-pixel core mask is built once.
    yy, xx = np.indices((stamp_side, stamp_side))
    core_mask = np.sqrt((xx - half)**2 + (yy - half)**2) <= 10

    def recenter(sub):
        # Shift the stamp so that its centre of mass falls on the central pixel.
        com = center_of_mass(sub)
        return shift(sub, (half - com[0], half - com[1]),
                     order=1, mode='constant', cval=0.0)

    filtered_positions = []
    for (sx, sy) in isolated:
        # Skip sources whose cutout would extend past the image edge.
        if (sx - half < 0 or sx + half >= data.shape[0] or
            sy - half < 0 or sy + half >= data.shape[1]):
            continue

        sub = data[sx-half : sx+half+1, sy-half : sy+half+1]

        # Skip sources with zero-valued pixels in the core.
        if np.any(sub[core_mask] == 0):
            continue

        # The source must be the brightest pixel of its own cutout.
        if sub[half, half] < np.max(sub):
            continue

        # Measure Radial Profile
        radii, prof = radial_profile(recenter(sub), half, half)
        if not is_monotonically_decreasing(prof, max_up_fraction=1.05):
            continue

        filtered_positions.append((sx, sy))

    # Select Initial List of 'num_star' Sources.
    # These positions already passed the bounds check above.
    cutouts = []
    for (sx, sy) in filtered_positions[:num_stars]:
        sub = data[sx-half:sx+half+1, sy-half:sy+half+1]
        if np.max(sub) >= sat_limit:
            continue

        if np.sum(sub) <= 0:
            continue

        # Re-center Source on Center of Mass
        sub_centered = recenter(sub)
        if np.sum(sub_centered) <= 0:
            continue

        cutouts.append(sub_centered)

    print(f"Final re-centered cutouts used: {len(cutouts)}")

    # Save Each Cutout to a Separate FITS File in Output Directory
    for i, c in enumerate(cutouts, start=1):
        cutout_file = os.path.join(output_dir, f"initial_cutout_{i}.fits")
        fits.PrimaryHDU(c).writeto(cutout_file, overwrite=True)

    # Begin Process to Build Initial PSF
    if len(cutouts) == 0:
        print("No valid cutouts found. Exiting.")
        return None

    masked_cutouts = []
    for i, stamp in enumerate(cutouts, start=1):
        bg_min = measure_min_nonzero_background_entire(stamp)
        stamp_bsub = np.clip(stamp - bg_min, 0, None)

        # Mask Companion Sources Surrounding Main Star
        stamp_masked = mask_local_maxima_around_companion(stamp_bsub,
                                                          main_star=None,
                                                          background_corner_size=10,
                                                          mask_radius=3,
                                                          local_max_size=3)
        masked_cutouts.append(stamp_masked)

        # Save Masked Cutouts to Separate FITS Files in Output Directory
        stampmask_file = os.path.join(output_dir, f"stampmask{i}.fits")
        fits.PrimaryHDU(stamp_masked).writeto(stampmask_file, overwrite=True)

    # Build Initial PSF from Masked Cutouts (NaN-masked pixels are ignored)
    stack = np.stack(masked_cutouts, axis=0)
    psf_init = np.nanmedian(stack, axis=0)
    psf_init = np.clip(psf_init, 0, None)

    init_sum = np.nansum(psf_init)
    if init_sum > 0:
        psf_init /= init_sum

    print(f"Initial PSF => sum={psf_init.sum():.3f}, max={psf_init.max():.3f}.")

    psf_current = psf_init.copy()
    fits.PrimaryHDU(psf_current).writeto(os.path.join(output_dir, "psf_initial.fits"), overwrite=True)

    # PSF Residual Iterative Refinement
    for it in range(iterations):
        print(f"\n--- Iteration {it+1}/{iterations} ---")
        residuals = []
        masks = []

        for stamp in cutouts:
            bg_min = measure_min_nonzero_background_entire(stamp)
            stamp_bsub = np.clip(stamp - bg_min, 0, None)
            peak_cut = stamp_bsub.max()

            # Fallback when no star can be identified: use the stamp itself
            # as the residual and keep every pixel.
            resid = stamp_bsub.copy()
            mask = np.ones_like(stamp, dtype=bool)

            if not (peak_cut <= 0):
                # Identify peaks: one per region brighter than 30 % of the maximum.
                lbl_main, n_main = label(stamp_bsub > 0.3 * peak_cut)
                if n_main >= 1:
                    loc_pos = maximum_position(stamp_bsub, labels=lbl_main,
                                               index=np.arange(1, n_main+1))
                    # Sort by Decreasing Brightness; index 0 is the main star.
                    loc_pos_sorted = sorted(loc_pos,
                                            key=lambda xy: stamp_bsub[xy[0], xy[1]],
                                            reverse=True)

                    # Linear Amplitude Fit of the top 5 brightest peaks (in a
                    # cutout with a significantly bright centre source this
                    # is normally just the main star).
                    _amps, model_img = linear_amplitude_fit(stamp_bsub, loc_pos_sorted[:5],
                                                            psf_current, delta=5.0,
                                                            max_iter=10, tol=1e-6)

                    # Residual Subtraction; negative values are clipped to zero.
                    resid = np.clip(stamp_bsub - model_img, 0, None)

                    # Mask Companion Cores (every peak except the main star)
                    mask = companion_mask(stamp.shape, loc_pos_sorted,
                                          main_star_index=0, mask_radius=mask_radius)

            residuals.append(resid)
            masks.append(mask)

        # Combine Residuals: masked-out pixels become NaN
        res_stack = np.stack(residuals, axis=0)
        mask_stack = np.stack(masks, axis=0)
        masked_res = np.where(mask_stack, res_stack, np.nan)

        # Median Stack Sigma-Clipped Residuals.
        # sigma_clip returns a MaskedArray (clipped and invalid values are
        # masked), so a mask-aware median is used; pixels with no valid
        # sample become NaN.
        clipped = sigma_clip(masked_res, sigma=sigma_clip_level,
                             maxiters=5, axis=0)
        median_resid = np.ma.filled(np.ma.median(clipped, axis=0), np.nan)

        # Save Median Stack for Each Iteration
        resid_median = os.path.join(output_dir, f"rmstack_iter{it+1}.fits")
        fits.PrimaryHDU(median_resid).writeto(resid_median, overwrite=True)

        # Print Residual Statistics
        finite_vals = median_resid[np.isfinite(median_resid)]
        if len(finite_vals) > 0:
            min_r = np.min(finite_vals)
            max_r = np.max(finite_vals)
            mean_r = np.mean(finite_vals)
        else:
            min_r = max_r = mean_r = 0.0
        print(f"  Residual stats: min={min_r:.4g}, max={max_r:.4g}, mean={mean_r:.4g}")

        # Perform Residual Subtraction to Obtain Updated PSF; clip negatives to zero
        new_psf = np.clip(psf_current - alpha * median_resid, 0, None)

        # Normalize the PSF
        psf_sum = np.sum(new_psf)
        if psf_sum > 0:
            new_psf /= psf_sum
        else:
            print("PSF collapsed to zero; stopping iteration.")
            break

        print(f"Iteration {it+1}: sum={np.sum(new_psf):.3f}, max={new_psf.max():.4f}.")

        # Save PSF Iteration to FITS File in Output Directory
        iter_file = os.path.join(output_dir, f"psf_iter{it+1}.fits")
        fits.PrimaryHDU(new_psf).writeto(iter_file, overwrite=True)

        # Update the 'psf_current' Assignment with the Updated PSF
        psf_current = new_psf.copy()

    # Final Trimming: Remove the 4 Outermost Rows/Columns on every side
    # (re-centring shifts leave artificial zeroed rows/columns around the edges)
    psf_final_trimmed = psf_current[4:-4, 4:-4]

    # Save Final PSF to FITS File in Output Directory
    final_file = os.path.join(output_dir, "psf_final.fits")
    fits.PrimaryHDU(psf_final_trimmed).writeto(final_file, overwrite=True)
    print(f"\nRefinement complete. Final PSF saved to {final_file}.")
    return psf_current


## Main Function Call

In [None]:
if __name__ == "__main__":
    # Driver cell: set the input/output locations and tuning parameters,
    # then run the full PSF construction and refinement.

    # Define paths to mosaic file and output directory
    # (placeholders: replace with real locations before running)
    mosaic_path = 'path_to_mosaic'
    output_dir  = 'output_directory'

    # Function Parameters; current values can be used as a baseline, but should be adjusted by the user based on the input data
    
    cutout_size     = 105     # cutout size of PSF (pixels by pixels); the saved final PSF has 4 pixels trimmed from every edge, i.e. 8 fewer rows and columns
    num_stars       = 200     # initial number of brightest accepted sources considered for the PSF; later reduced by the saturation and flux checks
    threshold_sigma = 4.0     # threshold sigma used in initial detection of potential sources
    min_sep         = 5       # minimum pixel separation between sources needed to not be considered binary pairs or 'close' sources
    sat_limit       = 80      # limit for saturated stars, in terms of pixel-valued brightness
    iterations      = 2       # number of refinement iterations
    alpha           = 0.001   # scales the median stacked residual to determine how much the residual subtraction influences the current PSF; higher values result in greater changes, lower values in more gradual changes
    sigma_clip_level= 10.0    # sigma level for outliers to be excluded in the median residual stack during the iterative PSF refinement process
    mask_radius     = 2       # half-width (pixels) of the square mask applied to companion cores during refinement

    # Returns the final (untrimmed) PSF array, or None if no usable cutout was found.
    # NOTE: save_residuals is accepted by the function but currently has no effect.
    final_psf = refine_psf_from_mosaic(
        mosaic_path   = mosaic_path,
        output_dir    = output_dir,
        cutout_size   = cutout_size,
        num_stars     = num_stars,
        threshold_sigma = threshold_sigma,
        min_sep       = min_sep,
        sat_limit     = sat_limit,
        iterations    = iterations,
        alpha         = alpha,
        sigma_clip_level = sigma_clip_level,
        mask_radius   = mask_radius,
        save_residuals=True
    )

The above PSF Development and Iterative Refinement Code was developed by Emily McCallum as part of her Applied Mathematics Senior Thesis at Harvard College. Latest update: 26 Mar 2025