# 🛰️ Shift-and-Stack Tutorial for Astronomical Image Processing

This tutorial demonstrates how to apply the **shift-and-stack** technique to detect faint moving objects in CCD images.
The method shifts multiple exposures to follow the object's predicted motion and co-adds them to increase the signal-to-noise ratio (S/N).

We will:
- Read and clean CCD images
- Use the WCS to align images on the target's known ephemeris
- Apply cosmic ray rejection
- Perform shift-and-stack using mean and median methods
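
Before working with real data, it can help to see the idea on a toy example. The sketch below is not part of the pipeline: it builds synthetic noise frames with a faint source that moves at an assumed constant rate (`rate_x`, `rate_y`, `flux`, and the frame size are all made-up values), shifts each frame back along the predicted motion with `np.roll`, and averages them. The source is marginal in any single frame but stands out in the stack, where the background noise drops roughly as the square root of the number of frames.

```python
import numpy as np

rng = np.random.default_rng(42)

n_frames, size = 25, 64
rate_x, rate_y = 2, 1          # assumed motion in pixels per frame (x, y)
start_x, start_y = 8, 20       # assumed position in the first frame
flux, noise_sigma = 3.0, 1.0   # per-frame S/N of about 3

frames = []
for k in range(n_frames):
    frame = rng.normal(0.0, noise_sigma, (size, size))
    # Arrays are indexed [row (y), column (x)]
    frame[start_y + k * rate_y, start_x + k * rate_x] += flux
    frames.append(frame)

# Shift each frame back by the predicted motion so the source lands on the
# same pixel in every frame. np.roll wraps around the edges, which is
# acceptable for this toy case but not for real images.
aligned = [np.roll(f, (-k * rate_y, -k * rate_x), axis=(0, 1))
           for k, f in enumerate(frames)]
stack = np.mean(aligned, axis=0)

peak_y, peak_x = np.unravel_index(np.argmax(stack), stack.shape)
print("brightest pixel in stack (x, y):", peak_x, peak_y)
print("noise in a single frame:", frames[0].std())
print("noise in the stack:", stack.std())
```

With 25 frames the stacked noise should be close to one fifth of the single-frame noise, and the brightest pixel should sit at the source's position in the first frame.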


In [None]:
# Import necessary libraries
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from astropy.wcs import WCS
from astropy.time import Time
from astropy.coordinates import SkyCoord
from astroquery.jplhorizons import Horizons
import os
import glob
import scipy.signal as signal


## Defining File Locations and the JPL Record ID

We need to tell the notebook which directories to search and which JPL Horizons Record ID identifies the orbital solution we want to query. You can find the Record ID by searching for the object on JPL Horizons and checking the 'Record ID' keyword in the output file.

In [None]:
location = 'directory for files to analyze'
record_id = 'JPL Horizons orbit record ID to reduce errors when querying'

## 🔭 Correcting WCS Based on Known Star
We first correct the World Coordinate System (WCS) so that all images align based on a reference star.
This ensures the stacking step follows the target object's motion properly. This step is optional, depending on the accuracy of the images' existing WCS solutions.

In [6]:
def shift_wcs_solutions():
    drz_list = glob.glob(location + '/science image folder')
    # For each file in drz_list, find a reference star in the image, compare its RA and Dec from the
    # file's WCS with the catalog values, and enter the offset (in pixels) in shift_array below.
    # All shifts default to zero, so this step can be run as-is.
    shift_array = np.array([['ifed01pfq',0,0],
                  ['ifed01pgq',0,0], 
                  ['ifed01phq',0,0],
                  ['ifed01piq',0,0],
                  ['ifed01pjq',0,0],
                  ['ifed01pkq',0,0], 
                  ['ifed01plq',0,0],
                  ['ifed01pmq',0,0]])
    
    for drz in drz_list:
        print(drz)
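        # Match on the 9-character rootname at the start of the file name rather than a fixed path offset
        shift_list = shift_array[shift_array[:,0]==os.path.basename(drz)[:9]][0]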
        with fits.open(drz) as hdul:
            wcs = WCS(hdul[1].header)
            print(shift_list[1])
            wcs.wcs.crpix += [int(shift_list[1]), int(shift_list[2])]
            hdul[1].header.update(wcs.to_header())
            hdul.writeto(drz[:-5]+'.fits',overwrite=True)  
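
The shifts in `shift_array` are applied to `CRPIX`, so they need to be in pixels. If you measure the reference star's offset on the sky instead, a rough conversion looks like the sketch below. It assumes a north-up, east-left image with square pixels and a known pixel scale; the helper name `sky_offset_to_pixels` and the 0.04 arcsec/pixel scale are illustrative, not values taken from the data. The sign to apply to `CRPIX` depends on which position you treat as the reference, so check it on one star before trusting it.

```python
import math

def sky_offset_to_pixels(ra_cat, dec_cat, ra_img, dec_img, pixel_scale_arcsec):
    """Pixel offset (dx, dy) of the catalog position relative to the image position.

    Assumes north up, east left, and square pixels (hypothetical helper).
    All coordinates are in degrees; the pixel scale is in arcsec per pixel.
    """
    # RA differences shrink on the sky by cos(Dec)
    d_ra = (ra_cat - ra_img) * math.cos(math.radians(dec_cat))
    d_dec = dec_cat - dec_img
    dx = -d_ra * 3600.0 / pixel_scale_arcsec   # east (increasing RA) is toward -x
    dy = d_dec * 3600.0 / pixel_scale_arcsec   # north (increasing Dec) is toward +y
    return dx, dy

# Example: catalog star 0.36" east and 0.36" south of where the WCS puts it
dx, dy = sky_offset_to_pixels(150.0001, 0.0, 150.0, 0.0001, 0.04)
print(round(dx), round(dy))
```

Because `shift_wcs_solutions` converts the entries with `int()`, round the result to whole pixels before putting it in `shift_array`.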

## ✨ Cleaning the Images
First, let's define a basic cosmic-ray rejection filter that replaces isolated bright pixels. There are other options, like LACosmic, that you can check out, but this one is simple. For each pixel on the detector, it takes the median of the surrounding pixels. If the center pixel is more than 15 times higher than that median, it is interpreted as a cosmic ray and "cleaned" by assigning that pixel the median value.


In [None]:
def clean_image(data_slice):
    '''
    Replace isolated bright pixels (cosmic-ray candidates) with the local median.

    Parameters
    ----------
    data_slice : 2D image array or IFU slice
    
    Returns
    -------
    clean_slice : 2D image array or IFU slice with cosmic rays removed

    '''
    #Work on a copy so the input is not modified and every comparison uses the original pixel values
    clean_slice = data_slice.copy()
    #Mask of strong outliers (0 = flagged); kept for inspection, not returned
    clean_mask = np.ones_like(data_slice)
    rows,cols = data_slice.shape[0],data_slice.shape[1]
    for row in range(0,rows):
        for col in range(0,cols):
            #Median of the 3x3 neighborhood, clipped at the image edges
            window = data_slice[max(row-1,0):row+2,max(col-1,0):col+2]
            perimeter_median = np.median(window)
            if abs(data_slice[row,col]) >= abs(15*perimeter_median):
                clean_slice[row,col] = perimeter_median
                if abs(data_slice[row,col]) >= abs(25*perimeter_median):
                    print('Flagged value: '+str(data_slice[row,col]))
                    print('Replaced with: '+str(perimeter_median))
                    clean_mask[row,col]= 0
    return clean_slice
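
The pixel-by-pixel loop is easy to follow but slow on full frames. As a rough sketch of the same idea without Python loops, the version below stacks the eight shifted copies of the image that surround each pixel and takes their median in one call. Note the assumptions: the function name `clean_image_vectorized` and its `threshold` argument are made up for this example, edges are handled by reflecting the image, and the median excludes the center pixel itself, so results can differ slightly from `clean_image`.

```python
import numpy as np

def clean_image_vectorized(image, threshold=15.0):
    """Replace pixels that exceed threshold x the median of their 8 neighbours.

    Returns the cleaned image and a boolean mask of the replaced pixels.
    """
    image = np.asarray(image, dtype=float)
    rows, cols = image.shape
    # Pad by one pixel so every pixel, including edges, has 8 neighbours
    padded = np.pad(image, 1, mode="reflect")
    # Collect the 8 shifted views around each pixel (centre excluded)
    neighbours = [padded[1 + dy:1 + dy + rows, 1 + dx:1 + dx + cols]
                  for dy in (-1, 0, 1) for dx in (-1, 0, 1)
                  if (dy, dx) != (0, 0)]
    local_median = np.median(neighbours, axis=0)
    hot = np.abs(image) >= np.abs(threshold * local_median)
    cleaned = np.where(hot, local_median, image)
    return cleaned, hot

# Toy check: a flat background with one "cosmic ray"
demo = np.ones((5, 5))
demo[2, 2] = 100.0
cleaned, hot = clean_image_vectorized(demo)
print(cleaned[2, 2], int(hot.sum()))
```

As with the loop version, a threshold defined as a multiple of the local median behaves poorly where the background is close to zero, so it is worth inspecting the mask before trusting the result.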

## 🧭 Image Alignment and Shift Measurement
Next let's define a function that finds the shift between images using cross-correlation, a common method for fine alignment. We make use of the `scipy.signal.correlate2d` function, which computes the cross-correlation of the two images; the location of its peak gives the offset at which they match best.


In [None]:
def find_image_shift(image1, image2):
    """
    Compute the best-fit alignment shift between two images using cross-correlation.
    :param image1: First image
    :param image2: Second image
    :return: (dx, dy) shift required to align image2 to image1
    """
    correlation = signal.correlate2d(image1, image2, boundary='symm', mode='same')
    y_max, x_max = np.unravel_index(np.argmax(correlation), correlation.shape)
    center_y, center_x = np.array(image1.shape) // 2
    dy = y_max - center_y
    dx = x_max - center_x
    return dx, dy
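
`correlate2d` works directly in pixel space, which gets slow for anything larger than a small cutout. A commonly used alternative is to compute the cross-correlation with FFTs. The sketch below is a swapped-in technique, not the function used in this notebook: it performs a circular cross-correlation with `numpy.fft`, so it assumes both images have the same shape and that wrap-around at the edges is acceptable. The name `fft_shift_estimate` is made up for this example.

```python
import numpy as np

def fft_shift_estimate(image1, image2):
    """Integer (dx, dy) such that np.roll(image2, (dy, dx), axis=(0, 1))
    best matches image1, using circular cross-correlation."""
    f1 = np.fft.fft2(image1)
    f2 = np.fft.fft2(image2)
    corr = np.fft.ifft2(f1 * np.conj(f2)).real
    dy, dx = np.unravel_index(np.argmax(corr), corr.shape)
    ny, nx = corr.shape
    # Lags beyond half the image size wrap around to negative shifts
    if dy > ny // 2:
        dy -= ny
    if dx > nx // 2:
        dx -= nx
    return int(dx), int(dy)

# Toy check: shift a random field by a known amount and recover it
rng = np.random.default_rng(1)
reference = rng.normal(size=(32, 32))
moved = np.roll(reference, (3, -5), axis=(0, 1))   # dy = +3, dx = -5
dx, dy = fft_shift_estimate(reference, moved)
print(dx, dy)   # the shift that undoes the roll: 5 -3
```

Whichever method you use, test it on a pair of images with a known offset first, since sign conventions are easy to get backwards.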

## 📊 Stack and Co-add Images
Now we can set up our loop to work through the science images to clean, align, and stack. This can be done with mean or median stacking.
Median stacking is robust to outliers (e.g. cosmic rays), while mean stacking gives lower noise when the data are well behaved and preserves flux better for photometry.
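
To make the trade-off concrete, here is a small synthetic comparison. It is only an illustration with made-up numbers: nine frames of Gaussian noise around a background of 10, with a single simulated cosmic-ray hit in one frame.

```python
import numpy as np

rng = np.random.default_rng(0)
frames = rng.normal(10.0, 1.0, size=(9, 32, 32))
frames[4, 16, 16] += 500.0        # one cosmic-ray hit in one frame

mean_stack = frames.mean(axis=0)
median_stack = np.median(frames, axis=0)

# The hit contaminates the mean but barely moves the median
print("hit pixel, mean stack:  ", mean_stack[16, 16])
print("hit pixel, median stack:", median_stack[16, 16])

# Away from the hit, the mean stack is the less noisy of the two
unaffected = np.ones((32, 32), dtype=bool)
unaffected[16, 16] = False
mean_noise = mean_stack[unaffected].std()
median_noise = median_stack[unaffected].std()
print("noise, mean stack:  ", mean_noise)
print("noise, median stack:", median_noise)
```

This is why it is common to clean the images first and then use the mean, or to use the median when outliers cannot be removed reliably.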


In [8]:
def crop_and_coadd(method = 'mean',best_fit=False,use_full=False,verbose = False,save_int = True):
    '''
    Parameters
    ----------
    method : str, 'mean' or 'median'
        Defines how images will be combined. The default is 'mean'.
    best_fit : Boolean, optional
        Toggle for if best fit routine is used to fine tune stack. The default is False.
    use_full : Boolean, optional
        Toggle to use full image to perform the best_fit. The default is False.
    verbose : Boolean, optional
        Toggle to print out intermediate information. The default is False.
    save_int : Boolean, optional
        Toggle to save intermediate image stacks. The default is True.

    Returns
    -------
    numpy.ndarray
        The mean- or median-combined stack of cutouts centered on the target.

    '''
    c_skycoord_list = []
    err_list = []
    images = []
    if method == 'median':
        med_bool = True
    else:
        med_bool = False
    if best_fit == True:
        #Pick a file with a WCS solution you trust here.
        comp_data= fits.getdata('file_name_Here')
    for i,obs in enumerate(os.listdir(location)):
        # print(file)
        if 'DS_Store' in obs:
            continue
        file = os.path.join(location,obs,str(obs)+'_drc_v3.fits')
        header = fits.getheader(file,0)
        
        #Access the wcs_header
        wcs_header = fits.getheader(file,1)
        data = fits.getdata(file)
        drz_wcs1 = WCS(wcs_header)
        
        #These may be different depending on the datasets. Double check
        obs_date = header['DATE-OBS']
        obs_time = header['TIME-OBS']
        obs_start = header['EXPSTART']
        obs_end = header['EXPEND']
        
        #define the observing time for the middle point of the observations
        obs_obj = Time((obs_start+obs_end)/2,format='mjd',scale='utc')
        obstime_jd = obs_obj.jd 
        print(obstime_jd)
        
        #Query Horizons for the orbit
        obj = Horizons(id=record_id,id_type='smallbody',location='@hst',epochs=obstime_jd)
        eph = obj.ephemerides()
        
        #Create a SkyCoord for the object at the midpoint of the observations
        c = SkyCoord(eph['RA'][0], eph['DEC'][0], frame='icrs', unit='deg')
        c_skycoord_list.append(c)
        err_list.append([eph['RA_3sigma'][0],eph['DEC_3sigma'][0]])
        cen_skycoord = SkyCoord(c.ra,c.dec,frame='icrs')
        
        #Clean the data if it needs it (you can also use LACosmic or a similar tool)
        data_clean = clean_image(data)
        
        #Convert the object's SkyCoord to pixel coordinates; world_to_pixel returns (x, y)
        pix_cen = drz_wcs1.world_to_pixel(cen_skycoord)
        pix_cen = [pix_cen[0], pix_cen[1]]
        if best_fit == True and len(images) >= 1:
            #Fine-tune the predicted position by cross-correlating a small cutout with the reference cutout
            if verbose == True:
                print('Starting alignment process for: '+str(obs))
            #pix_cen = align_images(data_clean,images,pix_cen,obs,median = med_bool)
            x_low,x_high = int(pix_cen[0])-50,int(pix_cen[0])+50
            y_low,y_high = int(pix_cen[1])-50,int(pix_cen[1])+50
            dx,dy = find_image_shift(data_clean[y_low:y_high,x_low:x_high], comp_data)
            pix_cen = [pix_cen[0]+dx,pix_cen[1]+dy]
        x_low,x_high = int(pix_cen[0])-500,int(pix_cen[0])+500
        y_low,y_high = int(pix_cen[1])-500,int(pix_cen[1])+500
        
        #NumPy arrays are indexed [row (y), column (x)]
        data_trim = data_clean[y_low:y_high,x_low:x_high]
        images.append(data_trim)
        if len(images) == 1:
            #Copy so the in-place sums below do not modify the arrays stored in images
            data_og = data.astype(float)
            data_sum = data_trim.astype(float)
            #Save a smaller cutout for the fitting steps, centered on the best-fit location of your object
            x_low,x_high = int(pix_cen[0])-50,int(pix_cen[0])+50
            y_low,y_high = int(pix_cen[1])-50,int(pix_cen[1])+50
            comp_data = data_clean[y_low:y_high,x_low:x_high]
        else:
            data_og += data
            data_sum += data_trim
            if save_int == True:
                #This saves the intermediate stack so that you can check for any mis-alignment in the process
                output = np.median(np.array(images), axis = 0)
                hdu_coadd = fits.PrimaryHDU(output)
                hdu_list = fits.HDUList([hdu_coadd])
                hdu_list.writeto('obj_img_stacked_intermediate'+str(i)+'.fits',overwrite=True)
            if use_full == True:
                #Use the current image as the reference cutout for the next fit
                x_low,x_high = int(pix_cen[0])-50,int(pix_cen[0])+50
                y_low,y_high = int(pix_cen[1])-50,int(pix_cen[1])+50
                comp_data = data_clean[y_low:y_high,x_low:x_high]
    #Save new stack to FITS file
    #Divide by the number of stacked images, not the last loop index
    hdu_coadd = fits.PrimaryHDU(data_sum/len(images))
    hdu_list = fits.HDUList([hdu_coadd])
    hdu_list.writeto('obj_img_stacked.fits',overwrite=True)
    
    data_med_array = np.array(images)
    if method == 'median':
        median = np.median(data_med_array, axis = 0)
        hdu_median = fits.PrimaryHDU(median)
        hdu_list = fits.HDUList([hdu_median])
        hdu_list.writeto('obj_img_median.fits',overwrite=True)
        plt.figure(4)
        plt.imshow(median, vmin=0, vmax=1,cmap='viridis')
        return median
    elif method == 'mean':
        mean = np.mean(data_med_array, axis = 0)
        hdu_mean = fits.PrimaryHDU(mean)
        hdu_list = fits.HDUList([hdu_mean])
        hdu_list.writeto('obj_img_mean.fits',overwrite=True)
        plt.figure(4)
        plt.imshow(mean, vmin=0, vmax=1,cmap='viridis')
        return mean

In [None]:
# Execute the full stacking pipeline
coadded_image = crop_and_coadd(method='median', verbose=True)
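
One caveat with the fixed-size crops in `crop_and_coadd`: if the target is closer to the image edge than the crop half-width, the slice bounds go negative or run past the array, and the cutouts no longer share a shape or a center. A possible way around this, sketched below with a hypothetical helper named `safe_cutout`, is to paste the overlapping region into a NaN-filled array of fixed size. Cutouts built this way can then be combined with `np.nanmean` or `np.nanmedian`.

```python
import numpy as np

def safe_cutout(image, x_center, y_center, half_size):
    """Square cutout centred on (x_center, y_center); off-image pixels are NaN."""
    size = 2 * half_size
    cutout = np.full((size, size), np.nan)
    x0 = int(x_center) - half_size
    y0 = int(y_center) - half_size
    # Overlap between the requested box and the image, in image coordinates.
    # image.shape is (rows, columns), i.e. (y, x).
    ix0, ix1 = max(x0, 0), min(x0 + size, image.shape[1])
    iy0, iy1 = max(y0, 0), min(y0 + size, image.shape[0])
    if ix0 < ix1 and iy0 < iy1:
        cutout[iy0 - y0:iy1 - y0, ix0 - x0:ix1 - x0] = image[iy0:iy1, ix0:ix1]
    return cutout

# Toy check: a target near the left edge of a 10x10 image
image = np.arange(100.0).reshape(10, 10)
cut = safe_cutout(image, x_center=1, y_center=5, half_size=3)
print(cut.shape)
print(cut[3, 3])   # the pixel at (x=1, y=5) in the original image
```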