<img style="float: center;" src='https://github.com/STScI-MIRI/MRS-ExampleNB/raw/main/assets/banner1.png' alt="stsci_logo" width="900px"/> 

<a id="title_ID"></a>
# Improving JWST MIRI/MRS WCS using simultaneous imaging #

**Goal:** Improving JWST MIRI/MRS WCS using source detection on simultaneous imaging.

**Author:** Boris Trahin, Staff Scientist II, MIRI team

**Last updated:** January 29th, 2024

<div class="alert alert-block alert-info">
How to use:

- Change mrs_dir and mirim_dir directories below to MRS _rate and MIRIM simultaneous _cal files directories.

- The search for the best source detection parameters is automatic when auto_find = True (this takes much more time); set auto_find = False to use the sigma and threshold values given below.

- Look at the results (plots below before and after the WCS correction).

- Set correct_mrs = True if the number of detected sources and the resulting offsets look good.
</div>

<div class="alert alert-block alert-warning">
Workbook still in progress.

- Optimization is needed.

</div>

***
## Import packages

In [1]:
import os
import glob
import numpy as np
from astropy.io import fits
from astropy.visualization import interval
import matplotlib.colors as matcol
from matplotlib import cm
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
import astropy.units as u
from photutils.detection import DAOStarFinder, IRAFStarFinder
from astropy.stats import sigma_clipped_stats, sigma_clip
from astropy.units import Quantity
from astroquery.gaia import Gaia
import logging
from stdatamodels.jwst import datamodels
from itertools import product

# widget for interactive plots
%matplotlib widget 
# %matplotlib inline
import matplotlib.pyplot as plt
logging.getLogger("astroquery").setLevel(logging.ERROR)

***
## Get data

<div class="alert alert-block alert-info">
Source detection and offset calculation are performed on MIRIM/MRS simultaneous imaging _cal files (i.e. outputs from Image2Pipeline).

MRS files to be modified are the _rate files (i.e. outputs from Detector1Pipeline). 
The assign_wcs step in spec2 uses the RA_REF/DEC_REF keywords together with the instrument distortion model to build a combined WCS object for the data. This WCS is stored in the ASDF extension and is then used by everything downstream.
</div>

In [None]:
# Directory containing the MRS _rate images to correct (stage 1 Detector1Pipeline outputs).
# User-specific path: edit before running.
mrs_dir = '/Users/btrahin/Data/03226/J1505+3721/MRS/Science/SLOWR1/stage1'

# Directory containing the MIRIM/MRS parallel _cal images used for the alignment
# (stage 2 Image2Pipeline outputs). User-specific path: edit before running.
mirim_dir = '/Users/btrahin/Data/03226/J1505+3721/MIRIM/Science/FULL_MRS/FASTR1/stage2'

In [None]:
# Collect the input files, leaving out the '_wcs_' products written by a previous
# run of this notebook. A glob class such as '[!wcs]' only tests ONE character
# (it would also drop any file whose name ends in 'w', 'c' or 's' right before
# the suffix), so the exclusion is done with an explicit suffix test instead.
mrs_files = np.array(sorted(
    f for f in glob.glob(os.path.join(mrs_dir, '*_rate.fits'))
    if not f.endswith('_wcs_rate.fits')
))
mirim_files = np.array(sorted(
    f for f in glob.glob(os.path.join(mirim_dir, '*_cal.fits'))
    if not f.endswith('_wcs_cal.fits')
))

print(f'Found {len(mrs_files)} MRS images to process in {mrs_dir}')
print(f'Found {len(mirim_files)} MIRIM//MRS images to process in {mirim_dir}')

***
## Perform WCS correction

### Define useful functions

<div class="alert alert-block alert-info">
find_sources is used to detect sources in the MIRIM images, using DAOStarFinder or IRAFStarFinder (selected with the method parameter). The background is first estimated with sigma-clipped statistics and its median is subtracted. The sigma and threshold parameters can be modified.
</div>

In [None]:
# find sources using DAOStarFinder or IRAFStarFinder
def find_sources(hdu, sigma, threshold, method='DAO'):
    """Detect point-like sources in an imager exposure.

    The background level and noise are estimated with sigma-clipped statistics
    (pixels flagged DO_NOT_USE in the DQ extension are masked), the median is
    subtracted and the selected star finder is run on the result.

    Parameters
    ----------
    hdu : opened FITS file
        Must provide 'SCI' and 'DQ' extensions and a FILTER keyword in the
        primary header.
    sigma : float
        Clipping level passed to sigma_clipped_stats.
    threshold : float
        Detection threshold, in units of the sigma-clipped standard deviation.
    method : str
        'DAO' (DAOStarFinder) or 'IRAF' (IRAFStarFinder).

    Returns
    -------
    positions_radec : list of [ra, dec]
        Sky positions of the detections, in degrees.
    positions_xy : list of arrays [x, y]
        Pixel centroids of the detections.
        Both lists are empty when nothing is detected.

    Raises
    ------
    ValueError
        If method is neither 'DAO' nor 'IRAF'.
    """
    # https://photutils.readthedocs.io/en/stable/api/photutils.detection.DAOStarFinder.html
    finders = {'DAO': DAOStarFinder, 'IRAF': IRAFStarFinder}
    if method not in finders:
        raise ValueError(f"Unknown detection method {method!r}; use 'DAO' or 'IRAF'")

    data = hdu['SCI'].data
    w = WCS(hdu['SCI'].header)
    mirifilt = hdu[0].header['FILTER']

    # Only pixels carrying the DO_NOT_USE bit are masked (not every non-zero DQ value)
    mask = np.bitwise_and(hdu['DQ'].data, datamodels.dqflags.pixel['DO_NOT_USE']) != 0

    # The digits of the filter name divided by 100 give the wavelength
    # (e.g. 'F770W' -> 7.7); the 1e-6 factor below treats it as microns.
    wave = int("".join(filter(str.isdigit, mirifilt)))/100
    # lambda/D converted from radians to arcsec; 6.5 is presumably the telescope
    # aperture in meters -- TODO confirm
    psf_arcsec = wave * 1e-6 / 6.5 / (2 * np.pi / (360 * 3600))
    # 0.11 is the pixel scale in arcsec/pixel assumed throughout this notebook
    psf_pixel = psf_arcsec/0.11

    _, median, std = sigma_clipped_stats(data, sigma=sigma, mask=mask)
    finder = finders[method](fwhm=psf_pixel, threshold=threshold*std)
    sources = finder(data - median, mask=mask)

    # The star finders return None when no source is above the threshold
    if sources is None or len(sources) == 0:
        return [], []

    x_pix = np.asarray(sources['xcentroid'])
    y_pix = np.asarray(sources['ycentroid'])
    positions_xy = list(np.transpose((x_pix, y_pix)))
    sky = w.pixel_to_world(x_pix, y_pix)
    positions_radec = [[ra, dec] for ra, dec in zip(sky.ra.deg, sky.dec.deg)]
    return positions_radec, positions_xy

<div class="alert alert-block alert-info">
get_gaia_catalog queries the Gaia DR3 catalog around the image reference position, and find_gaia_sources keeps the catalog sources that fall inside the imager FOV.
</div>

In [None]:
# find Gaia DR3 sources in the FOV
def get_gaia_catalog(hdu):
    """Query the Gaia archive around the reference position of an image.

    The search is centred on CRVAL1/CRVAL2 of the 'SCI' header and uses a radius
    of NAXIS1 pixels, with the pixel size taken from CDELT1 when present and
    from the square root of PIXAR_A2 otherwise.

    Parameters
    ----------
    hdu : opened FITS file
        Must provide a 'SCI' extension with the keywords listed above.

    Returns
    -------
    Table restricted to the 'ra', 'dec', 'ra_error' and 'dec_error' columns.
    The catalog queried is the astroquery default Gaia table (this notebook
    assumes it is DR3 -- verify against the installed astroquery version).
    """
    header = hdu['SCI'].header
    coord = SkyCoord(ra=header['CRVAL1'], dec=header['CRVAL2'], unit=(u.deg, u.deg))
    if 'CDELT1' in header:
        # CDELT1 is frequently negative for the RA axis; a radius must be positive
        pixel_arcsec = 3600 * abs(header['CDELT1'])
    else:
        pixel_arcsec = np.sqrt(header['PIXAR_A2'])
    radius = Quantity(pixel_arcsec * header['NAXIS1'], u.arcsec)
    # Remove the row limit of the query (note: this changes a module-wide setting)
    Gaia.ROW_LIMIT = -1
    gaia_query = Gaia.query_object_async(coordinate=coord, radius=radius)
    reduced_query = gaia_query['ra', 'dec', 'ra_error', 'dec_error']
    return reduced_query

def find_gaia_sources(hdu, reduced_query):
    """Keep the catalog sources that fall on the image.

    Catalog positions are projected to pixel coordinates with the WCS of the
    'SCI' header and retained when 0 < x < nx and 0 < y < ny, where (ny, nx) is
    the shape of the 'SCI' array (so subarray images are handled as well as
    full-frame ones).

    Parameters
    ----------
    hdu : opened FITS file
        Must provide a 'SCI' extension with image data and a WCS header.
    reduced_query : table
        Rows with 'ra', 'dec', 'ra_error' and 'dec_error' columns, as returned
        by get_gaia_catalog.

    Returns
    -------
    positions_radec : list of [ra, dec]
    positions_xy : list of arrays [x, y]
    err_positions_radec : list of [ra_error, dec_error]
        The three lists are index-aligned and contain only the sources on the image.
    """
    sci = hdu['SCI']
    w = WCS(sci.header)
    ny, nx = sci.data.shape[-2:]
    sky = SkyCoord(ra=reduced_query['ra'], dec=reduced_query['dec'], unit=(u.deg, u.deg))
    x_pix, y_pix = w.world_to_pixel(sky)
    pixels = np.transpose((np.atleast_1d(x_pix), np.atleast_1d(y_pix)))

    positions_xy = []
    positions_radec = []
    err_positions_radec = []
    for row, xy in zip(reduced_query, pixels):
        if 0 < xy[0] < nx and 0 < xy[1] < ny:
            positions_xy.append(xy)
            positions_radec.append([row['ra'], row['dec']])
            err_positions_radec.append([row['ra_error'], row['dec_error']])
    return positions_radec, positions_xy, err_positions_radec

<div class="alert alert-block alert-info">
find_common_sources is used to create a list of common sources from the Gaia DR3 catalog and the ones detected in the image. Radius can be changed to increase/decrease the distance between sources for a match.
</div>

In [None]:
# find common Gaia DR3/image sources
def find_common_sources(sources1, sources2, radius, pixel_scale=0.11):
    """Match detected sources to catalog sources by nearest neighbour.

    Each entry of sources1 is paired with its closest entry of sources2 and the
    pair is kept when their separation is at most radius / pixel_scale pixels.
    A catalog source claimed by more than one detection is discarded entirely
    (this avoids picking up detections in PSF wings).

    Parameters
    ----------
    sources1 : sequence of (x, y)
        Pixel positions of the detected sources.
    sources2 : sequence of (x, y)
        Pixel positions of the Gaia sources.
    radius : float
        Maximum separation for a match, in arcsec.
    pixel_scale : float, optional
        Pixel size in arcsec/pixel used to convert radius to pixels.

    Returns
    -------
    list of [index_in_sources1, index_in_sources2]
        Empty when either input is empty or nothing matches.
    """
    detected = np.asarray(sources1, dtype=float).reshape(-1, 2)
    catalog = np.asarray(sources2, dtype=float).reshape(-1, 2)
    # Nothing to match: avoid taking argmin over an empty array
    if detected.shape[0] == 0 or catalog.shape[0] == 0:
        return []

    max_separation = radius / pixel_scale
    candidates = []
    for i, position in enumerate(detected):
        deltas = catalog - position
        dist2 = np.einsum('ij,ij->i', deltas, deltas)
        nearest = int(np.argmin(dist2))
        if np.sqrt(dist2[nearest]) <= max_separation:
            candidates.append((i, nearest))

    # Count how many detections claim each catalog source
    claims = {}
    for _, j in candidates:
        claims[j] = claims.get(j, 0) + 1
    return [[i, j] for i, j in candidates if claims[j] == 1]

<div class="alert alert-block alert-info">
plot_images is used to plot the MIRIM image with the detected sources, Gaia sources and common sources used for the offset computation. The onesource parameter is used to zoom in on a source.
</div>

In [None]:
scale = interval.ZScaleInterval(n_samples=800, contrast=0.3, max_reject=0.5, min_npixels=5, krej=2.5,max_iterations=5)

def plot_images(hdu, sources, gaia, common, index, plot=True, onesource=False):
    """Draw one image with the detected, Gaia and matched sources overlaid.

    The image is placed in subplot number index of a 3-column grid sized from
    the module-level mirim_files list, and displayed with the ZScale limits of
    the module-level scale object.

    Parameters
    ----------
    hdu : opened FITS file
        Must provide a 'SCI' extension and a FILENAME keyword in the primary header.
    sources : sequence of (x, y)
        Pixel positions of the detected sources (drawn in red).
    gaia : sequence of (x, y)
        Pixel positions of the Gaia sources (drawn in blue).
    common : sequence of [source_index, gaia_index]
        Matched pairs; the detected member of each pair is drawn in yellow.
    index : int
        1-based subplot position.
    plot : bool
        When False nothing is drawn.
    onesource : bool
        Zoom on a 40x40 pixel window around the first matched source
        (ignored when there is no match).
    """
    if not plot:
        return

    ax = plt.subplot(len(mirim_files)//3+1, 3, index)
    data = hdu['SCI'].data
    vmin, vmax = scale.get_limits(data)
    # 'gray' is used rather than the 'grey' alias, which older matplotlib versions lack
    ax.imshow(data, origin='lower', norm=matcol.Normalize(vmin=vmin, vmax=vmax), cmap='gray')

    def mark(position, color):
        # small cross at the position plus a 10-pixel radius circle around it
        ax.plot(position[0], position[1], marker='x', color=color, markersize=0.5)
        ax.add_artist(plt.Circle((position[0], position[1]), 10, fill=False, color=color, linewidth=0.3, linestyle='-'))

    for g in gaia:
        mark(g, 'blue')
    for s in sources:
        mark(s, 'red')
    for c in common:
        mark(sources[c[0]], 'yellow')

    # Empty proxy artists: one legend entry per category actually present
    legend_entries = (
        (gaia, 'blue', 'Gaia DR3 source'),
        (common, 'yellow', 'Common sources'),
        (sources, 'red', 'Detected source'),
    )
    for items, color, label in legend_entries:
        if len(items) > 0:
            ax.plot([], [], marker='o', markerfacecolor='None', markeredgecolor=color, markersize=5, markeredgewidth=0.3, linestyle='', label=label)

    # Zooming needs at least one matched source to centre on
    if onesource and len(common) > 0:
        x0, y0 = sources[common[0][0]][0], sources[common[0][0]][1]
        ax.set_xlim(x0-20, x0+20)
        ax.set_ylim(y0-20, y0+20)

    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=4, loc='lower left')
    ax.set_title(f"{hdu[0].header['FILENAME']}",fontsize=5)
    plt.tight_layout()
    

<div class="alert alert-block alert-info">
shift_compute is used to compute the offset in RA and DEC. For now, it just calculates the median of the offsets after sigma clipping.
</div>

In [None]:
# compute offset (to be modified)
def shift_compute(sources, gaia, common, err_gaia=None):
    """Compute the RA/DEC offset between detected and Gaia positions.

    For every matched pair the difference (detected - Gaia) is taken in each
    coordinate; the returned offsets are the medians of these differences after
    a 2-sigma clipping.

    Parameters
    ----------
    sources : sequence of [ra, dec]
        Sky positions of the detected sources.
    gaia : sequence of [ra, dec]
        Sky positions of the Gaia sources.
    common : sequence of [source_index, gaia_index]
        Matched pairs, indexing into sources and gaia.
    err_gaia : sequence of [ra_error, dec_error], optional
        Gaia position errors, index-aligned with gaia. When given, differences
        whose error/position ratio is not below 0.05 are dropped.

    Returns
    -------
    delta_ra, delta_dec : float
        Clipped median offsets, in the units of the input positions.
        (nan, nan) when there is no usable pair.
    """
    if len(common) == 0:
        return np.nan, np.nan

    gaia_common = np.array([gaia[pair[1]] for pair in common])
    sources_common = np.array([sources[pair[0]] for pair in common])
    deltas_ra = sources_common[:, 0] - gaia_common[:, 0]
    deltas_dec = sources_common[:, 1] - gaia_common[:, 1]

    # 'is not None' rather than truthiness, so that arrays are accepted too
    if err_gaia is not None and len(err_gaia) > 0:
        err_gaia_common = np.array([err_gaia[pair[1]] for pair in common])
        # NOTE(review): this divides a position error by the coordinate value
        # itself, so the cut depends on where the field is on the sky (and always
        # passes for negative declinations) -- confirm the intended criterion.
        deltas_ra = deltas_ra[err_gaia_common[:, 0]/gaia_common[:, 0] < 0.05]
        deltas_dec = deltas_dec[err_gaia_common[:, 1]/gaia_common[:, 1] < 0.05]

    def clipped_median(values):
        # median after 2-sigma clipping; nan when nothing is left to average
        if len(values) == 0:
            return np.nan
        _, median, _ = sigma_clipped_stats(values, sigma=2)
        return median

    return clipped_median(deltas_ra), clipped_median(deltas_dec)

<div class="alert alert-block alert-info">
obs_association is used to create a dictionary mapping each MIRIM _cal file to the corresponding simultaneous MRS _rate observations.
</div>

In [None]:
# associates mirim obs with mrs obs (to be modified)
def obs_association(mirim_files, mrs_files):
    """Associate each MIRIM file with its simultaneous MRS files.

    Two files are associated when the PROGRAM, VISITGRP, OBSERVTN and EXPOSURE
    keywords of their primary headers are all equal. Each MRS file is assigned
    to at most one MIRIM file: the first one, in input order, that matches.

    Parameters
    ----------
    mirim_files : iterable of paths to the MIRIM files
    mrs_files : iterable of paths to the MRS files

    Returns
    -------
    dict
        Maps every MIRIM path to the (possibly empty) list of matching MRS
        paths, in the order of mrs_files.
    """
    keywords = ('PROGRAM', 'VISITGRP', 'OBSERVTN', 'EXPOSURE')

    def observation_key(path):
        # identifying keywords read from the primary header
        header = fits.getheader(path, 0)
        return tuple(header[key] for key in keywords)

    # Read every MRS header once and group the files by observation key,
    # instead of re-opening all MRS files for each MIRIM file.
    mrs_by_key = {}
    seen = set()
    for mrs in mrs_files:
        if mrs in seen:
            continue
        seen.add(mrs)
        mrs_by_key.setdefault(observation_key(mrs), []).append(mrs)

    dic = {}
    for mirim in mirim_files:
        # pop: an MRS file already assigned is not offered to a later MIRIM file
        dic[mirim] = mrs_by_key.pop(observation_key(mirim), [])
    return dic

### Source detection and WCS correction of the MRS observations 

<div class="alert alert-block alert-info">
User has to set a sigma and threshold value for sources detection.

Some useful parameters:
- auto_find tries to find the best sigma/threshold values in the sigma_values/threshold_values but takes much more time to run. Best sigma values are printed. Ranges can be adjusted for better detection.
- plot shows "interactive" images with Gaia/sources detected BEFORE correction
- correct_mrs = True if number of detected sources is OK
</div>

In [None]:
# --- Fixed detection parameters (used when auto_find is False) ---
# sigma: clipping level used in find_sources to estimate the background and its noise
sigma = 2.5
# threshold: sources must peak roughly this many noise sigmas above the background
threshold = 4.

# --- Automatic parameter search ---
# When True, every sigma/threshold pair of the grids below is tried and the one
# giving the most common sources is kept (slower)
auto_find = True
sigma_values = np.arange(2.0, 5.0, 0.25)
threshold_values = np.arange(2.0, 8.0, 0.5)

# --- Detection algorithm: 'IRAF' or 'DAO' (find_sources defaults to 'DAO') ---
method = 'IRAF'

# --- Outputs ---
# Show the images with Gaia/detected sources BEFORE correction
plot = True
# Write the WCS-corrected MRS files
correct_mrs = False
# Write the WCS-corrected MIRIM files
correct_mirim = True

# Map every MIRIM exposure to its simultaneous MRS exposures
dic = obs_association(mirim_files, mrs_files)
print(dic)

<div class="alert alert-block alert-info">
Here is the main code.
</div>

In [None]:
if plot:
    fig = plt.figure(figsize=(10,10), dpi=150)

for i, mirim_file in enumerate(mirim_files):
    with fits.open(mirim_file) as hdu:
        print(f'File {mirim_file}')
        if auto_find:
            print('Finding best parameters for source detection...')
        # The Gaia catalog is queried once, around the first image, and reused for the others
        if i == 0:
            gaia_catalog = get_gaia_catalog(hdu)
        # Gaia pixel positions depend only on the image WCS, not on the detection
        # parameters, so they are computed once per image (outside the grid scan)
        best_gaia_radec, best_gaia_xy, best_gaia_err = find_gaia_sources(hdu, gaia_catalog)

        if auto_find:
            num_matches = 0
            for sigma_auto, threshold_auto in product(sigma_values, threshold_values):
                sources_positions_radec, sources_positions_xy = find_sources(hdu, sigma=sigma_auto, threshold=threshold_auto, method=method)
                common_sources = find_common_sources(sources_positions_xy, best_gaia_xy, radius=2)
                # '<=' keeps the last pair reaching the highest number of matches
                if num_matches <= len(common_sources) <= len(best_gaia_xy):
                    num_matches = len(common_sources)
                    best_sigma = sigma_auto
                    best_threshold = threshold_auto
                    best_sources_radec, best_sources_xy = sources_positions_radec, sources_positions_xy
                    best_common = common_sources
            print(f'Best sigma, threshold values used: {best_sigma, best_threshold}')
        else:
            best_sources_radec, best_sources_xy = find_sources(hdu, sigma=sigma, threshold=threshold, method=method)
            best_common = find_common_sources(best_sources_xy, best_gaia_xy, radius=2)
            print(f'Sigma, threshold values used: {sigma, threshold}')

        delta_ra, delta_dec = shift_compute(best_sources_radec, best_gaia_radec, best_common, None)
        print(f'Number of sources detected: {len(best_sources_xy)}')
        print(f'Number of Gaia DR3 sources: {len(best_gaia_xy)}')
        print(f'Number of sources in common: {len(best_common)}')
        print(f'Shifts RA, DEC (deg): {delta_ra}, {delta_dec}\n')
        plot_images(hdu, best_sources_xy, best_gaia_xy, best_common, i+1, plot, onesource=False)

        # Without any usable match the shift is not a finite number; such a value
        # must not be written into the headers
        shift_is_valid = bool(np.isfinite(delta_ra) and np.isfinite(delta_dec))
        if not shift_is_valid and (correct_mrs or correct_mirim):
            print('No valid shift for this file: WCS correction skipped.\n')

        if correct_mrs and shift_is_valid:
            for mrs_obs in dic[mirim_file]:
                with fits.open(mrs_obs) as hdu_mrs:
                    hdr_mrs = hdu_mrs['SCI'].header
                    # The measured offset (detected - Gaia) is removed from the reference position
                    hdr_mrs['RA_REF']-=delta_ra
                    hdr_mrs['DEC_REF']-=delta_dec
                    hdr_mrs.append(('DELT_RA', delta_ra, 'WCS corr added to RA_REF from simult. imaging'))
                    hdr_mrs.append(('DELT_DEC', delta_dec, 'WCS corr added to DEC_REF from simult. imaging'))
                    # Written next to the input file, which is left untouched
                    hdu_mrs.writeto(mrs_obs.replace('_rate.fits', '_wcs_rate.fits'), overwrite=True)
        if correct_mirim and shift_is_valid:
            sci_header = hdu['SCI'].header
            sci_header['RA_REF']-=delta_ra
            sci_header['DEC_REF']-=delta_dec
            sci_header['CRVAL1']-=delta_ra
            sci_header['CRVAL2']-=delta_dec
            hdu.writeto(mirim_file.replace('_cal.fits', '_wcs_cal.fits'), overwrite=True)

<div class="alert alert-block alert-info">
Plot positions of the Gaia sources AFTER WCS correction. 

Needs correct_mirim=True to get the corrected files (_wcs_cal.fits)
</div>

In [None]:
if correct_mirim:
    # Corrected MIRIM files written by the correction step (_wcs_cal.fits)
    mirim_files_corr = np.array(sorted(glob.glob(os.path.join(mirim_dir, '*wcs_cal.fits'))))

    # Nothing is drawn when plot is False, so neither the figure nor the files are opened
    if plot:
        fig = plt.figure(figsize=(10,10), dpi=150)

        for index, corrected_file in enumerate(mirim_files_corr, start=1):
            with fits.open(corrected_file) as hdu:
                # Gaia positions projected with the corrected header; no detected/common sources are overlaid
                gaia_positions_radec, gaia_positions_xy, err_gaia_positions_radec = find_gaia_sources(hdu, gaia_catalog)
                plot_images(hdu, [], gaia_positions_xy, [], index, plot, onesource=False)