# Background Subtraction Tool

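This notebook demonstrates how to use the `AstroBkgInterp` (ABI) tool to interpolate and subtract the background from a 3-D spectral cube (here, a JWST MIRI MRS Level 3 `s3d` cube), and then compare the resulting 1-D spectrum with the standard pipeline spectrum.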

## Import Packages

- `AstroBkgInterp` is our background interpolation tool
- `numpy` for array processing and math
- `astropy.io` for accessing the data
- `astropy.stats` for calculating statistics on the data
- `matplotlib` for plotting images and spectra
- `photutils.detection` for finding stars in the data

Optional: 
- `jwst` for running the JWST Calibration pipeline on JWST data (Note: you will need to install this package separately, as it is not included in the standard environment install for AstroBkgInterp). See https://github.com/spacetelescope/jwst for installation instructions.

In [None]:
import copy
import os, sys

# Import the background subtraction tool 
from AstroBkgInterp import AstroBkgInterp

# Import astropy packages
from astropy.io import fits
from astropy.stats import sigma_clipped_stats
import astropy.units as u

# Import packages for displaying images in notebook
from matplotlib import pyplot as plt
from matplotlib.patches import Circle
from mpl_toolkits.axes_grid1 import make_axes_locatable

# For data handling
import numpy as np

# To find stars in the MRS spectralcubes 
from photutils.detection import DAOStarFinder

## Set the path to the input data

In [None]:
# Directory holding the stage-3 products, and the full-band s3d cube inside it.
DATA_DIR = '/path/to/data'
path = f'{DATA_DIR}/Level3_ch1-2-3-4-shortmediumlong_s3d.fits'

## Open and display the data 

In [None]:
hdu = fits.open(path)
data = hdu[1].data

# Replace NaN pixels with 0. This edits the array in place, so the array
# held by hdu[1] is changed as well (it is the same object).
nan_pixels = np.isnan(data)
data[nan_pixels] = 0

img = plt.imshow(data[7000], origin='lower')
plt.colorbar(img)

### Now detect the point source

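Collapse the cube into a 2-D median image (the median over the first axis at each spatial pixel), then find point sources in that image with the `DAOStarFinder` source-detection algorithm.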

In [None]:
# Collapse the cube along its first axis (wavelength for an s3d cube) into a
# 2-D median image. One vectorized call yields the same per-pixel medians as
# looping over every spatial pixel in Python, but far faster. Cast to float64
# to match the dtype the original per-pixel fill produced.
cube = np.median(data, axis=0).astype(float)

mean, median, std = sigma_clipped_stats(cube, sigma=3)

# Detect sources peaking above 3 sigma of the sigma-clipped background.
daofind = DAOStarFinder(fwhm=3.0, threshold=3. * std)
sources = daofind(cube - median)

# DAOStarFinder returns None (not an empty table) when nothing is found;
# stop here with a clear message instead of failing later on len(None).
if sources is None:
    raise RuntimeError('DAOStarFinder found no sources in the median image; '
                       'check the data or lower the detection threshold.')
print("\n Number of sources in field:", len(sources))

In [None]:
# Overlay every detected source position on the reference slice.
img = plt.imshow(data[7000], origin='lower')
plt.colorbar(img)
plt.scatter(x=sources['xcentroid'], y=sources['ycentroid'],
            marker="+", c="black", s=50)

If more than one source is detected, take the one with the highest value in the median-collapsed image as the primary source.

In [None]:
# Sample the median-collapsed image at each detected centroid, rounded to the
# nearest pixel. NumPy images are indexed [row, column], i.e. [y, x]; indexing
# with (x, y) would read the wrong pixel (or go out of bounds on non-square
# images). Clip so a centroid rounded onto the edge stays inside the image.
x_pix = np.clip(np.round(np.asarray(sources['xcentroid'])).astype(int), 0, cube.shape[1] - 1)
y_pix = np.clip(np.round(np.asarray(sources['ycentroid'])).astype(int), 0, cube.shape[0] - 1)
peakpixval = cube[y_pix, x_pix]

# The source with the highest sampled value becomes the primary source.
brightest = np.argmax(peakpixval)
src_x, src_y = sources['xcentroid'][brightest], sources['ycentroid'][brightest]
print(f'peak pixel x = {src_x}')
print(f'peak pixel y = {src_y}')


### Set size of aperture and annulus for source masking

In [None]:
# Source-masking geometry in pixels: aperture radius, then annulus width.
aper_rad, ann_width = 5, 6

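Use the plot below to check the source-masking geometry: the inner circle is the aperture (radius `aper_rad`) and the outer circle is the outer edge of the annulus (radius `aper_rad + ann_width`). Adjust the values above if needed.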

In [None]:
fig, ax = plt.subplots(figsize=(7, 7))
im = ax.imshow(data[7000], origin='lower')
fig.colorbar(im, ax=ax)
ax.plot(src_x, src_y, 'rx')

# Inner circle: aperture; outer circle: outer edge of the masking annulus.
circ = Circle((src_x, src_y), radius=aper_rad, color='r', fill=False)
annin = Circle((src_x, src_y), radius=aper_rad + ann_width, color='r', fill=False)
for patch in (circ, annin):
    ax.add_patch(patch)

# Run the background tool

In [None]:
bi = AstroBkgInterp()

# Configure the tool; targets in each tuple assignment are set left to right,
# so the attributes are assigned in the same order as before.
bi.src_y, bi.src_x = src_y, src_x                          # source position
bi.aper_rad, bi.ann_width = aper_rad, ann_width            # source masking
bi.bkg_mode, bi.k, bi.bin_size = 'polynomial', 3, 5        # background model
bi.pool_size, bi.cube_resolution = 12, 'high'              # multiprocessing

diff, bkg, mask = bi.run(data)

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(8, 10), constrained_layout=True)

# Wavelength slice shown in every panel, and shared colour limits so the
# four panels are directly comparable.
slice_idx = 1000
vmin, vmax = -10, 180

# (axes, title, image) for each panel, in display order.
panels = [
    (ax[0, 0], 'Data', data[slice_idx]),
    (ax[0, 1], 'Source Masked', mask[slice_idx][0]),
    (ax[1, 0], 'Interpolated Bkg', bkg[slice_idx]),
    (ax[1, 1], 'Residual', diff[slice_idx]),
]

im_list = []
for panel_ax, title, image in panels:
    panel_ax.set_title(title)
    im_list.append(panel_ax.imshow(image, vmin=vmin, vmax=vmax, origin='lower'))

# One colorbar serves all panels because they share vmin/vmax.
cbar = fig.colorbar(im_list[0], ax=ax, orientation='vertical', fraction=0.1, pad=0.03, shrink=0.7, aspect=30)

plt.show()

In [None]:
# Stack the per-slice background-subtracted images into one array.
newdata = np.array(list(diff))

In [None]:
# Deep-copy the opened HDUList so a modified copy can be written out while
# the original `hdu` object is left untouched.
newhdu = copy.deepcopy(hdu)

In [None]:
# Put the background-subtracted cube into extension 1 (the extension the data
# was read from) and write the file, replacing any existing copy.
sci_ext = newhdu[1]
sci_ext.data = newdata
newhdu.writeto(fileobj='newdata_high_res.fits', overwrite=True)

In [None]:
# Re-draw the aperture and annulus on a slice of the background-subtracted cube.
plt.imshow(newdata[5000], origin='lower')
plt.plot(src_x, src_y, 'rx')

circ, annin = (Circle((src_x, src_y), radius=r, color='r', fill=False)
               for r in (aper_rad, aper_rad + ann_width))
axes_now = plt.gca()
for patch in (circ, annin):
    axes_now.add_patch(patch)

-----------------

## For JWST data: extract a 1-D spectrum with the JWST pipeline

In [None]:
import os

# CRDS settings are placed in the environment before jwst is imported.
os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds.stsci.edu'
# expanduser('~') works on every platform; the HOME variable is not defined
# on Windows, where os.environ['HOME'] would raise KeyError.
os.environ['CRDS_PATH'] = os.path.join(os.path.expanduser('~'), 'crds_cache')

from jwst.extract_1d import Extract1dStep
from jwst import datamodels

In [None]:
# Open the background-subtracted cube under its own name so the 2-D median
# image `cube` from the source-detection step is not overwritten.
cube_model = datamodels.open('newdata_high_res.fits')

# Extract a 1-D spectrum centred on the detected source. The background was
# already removed by ABI, so the step's own background subtraction is off.
# `call` is a class method, so no step instance is needed.
result = Extract1dStep.call(cube_model,
                            subtract_background=False,
                            center_xy=[src_x, src_y],
                            ifu_rfcorr=True)

In [None]:
# Save the extracted spectrum and reopen it as a plain FITS file.
spec_file = 'newdata_high_res_spec2.fits'
result.to_fits(spec_file, overwrite=True)
res_pipe = fits.open(spec_file)

### Plot the ABI (AstroBkgInterp) background-subtracted spectrum

In [None]:
# Wavelength and flux columns of the spectrum table in extension 1.
spec = res_pipe[1].data
WAVE, FLUX = spec['WAVELENGTH'], spec['FLUX']

In [None]:
# Attach Jy units to the flux column, then convert to mJy.
FLUX_mjy = u.Quantity(FLUX, unit=u.Jy).to(u.mJy)

In [None]:
# Peak flux of the extracted spectrum, ignoring NaN samples (a single NaN
# would make a plain .max() return NaN).
np.nanmax(FLUX_mjy)

In [None]:
fig, ax = plt.subplots(figsize=(15, 9))
ax.tick_params(size=7, width=2, direction='inout', labelsize=12)

# Plot plain values (already in mJy) so matplotlib does not need astropy
# unit support enabled.
ax.plot(WAVE, FLUX_mjy.value, lw=0.5, label='2D Interp Bkg')

ax.set_title('Background Subtracted Spectrum', fontsize=20)
ax.set_xlim(4.8, 28)
ax.set_ylim(0, 1)
ax.set_xlabel(r'Wavelength ($\mu m$)', fontsize=15)
ax.set_ylabel('Flux (mJy)', fontsize=15)

# The line label above is only displayed once a legend is drawn.
ax.legend()

fig.tight_layout()
plt.show()

#### Compare the ABI spectrum with the pipeline spectrum

In [None]:
# Pipeline x1d spectrum for the SAME cube the ABI spectrum was extracted from.
# The comparison below subtracts the two spectra sample by sample, so a
# spectrum from a different band (e.g. ch2-short only) would not line up.
# Assumes the stage-3 x1d file sits next to the s3d cube — TODO confirm.
S3D_SUFFIX = '_s3d.fits'
if not path.endswith(S3D_SUFFIX):
    raise ValueError(f'expected an s3d cube path ending in {S3D_SUFFIX!r}: {path}')
orig_x1d_path = path[:-len(S3D_SUFFIX)] + '_x1d.fits'

origdata = fits.open(orig_x1d_path)
origspec = origdata[1].data
origWAVE = origspec['WAVELENGTH']
origFLUX = origspec['FLUX']

In [None]:
# Pipeline flux in mJy (the FLUX column is in Jy).
origFLUX_mjy = u.Quantity(origFLUX, unit=u.Jy).to(u.mJy)

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(15, 9), sharex=True)
for panel in ax:
    panel.tick_params(size=7, width=2, direction='inout', labelsize=12)

# The residual needs both spectra on the same wavelength grid. When they were
# extracted from the same cube the grids are identical and are subtracted
# directly; otherwise resample the ABI flux onto the pipeline grid instead of
# failing on mismatched array lengths (np.interp needs WAVE increasing).
if WAVE.shape == origWAVE.shape and np.allclose(WAVE, origWAVE, equal_nan=True):
    flux_residual = origFLUX - FLUX
else:
    flux_residual = origFLUX - np.interp(origWAVE, WAVE, FLUX)

ax[0].plot(origWAVE, origFLUX, lw=0.5, c='c', label='Pipeline')
ax[0].plot(WAVE, FLUX, lw=0.5, c='m', label='ABI')
ax[1].plot(origWAVE, flux_residual, lw=0.5, c='k', label='Difference')
ax[1].axhline(0, ls='--', c='r')

ax[0].legend()
ax[1].legend()

ax[0].set_title('Background Subtraction comparison', fontsize=20)

ax[0].set_ylim(1e-4, 5e-2)
ax[1].set_ylim(-0.012, 0.012)

ax[1].set_xlabel(r'Wavelength ($\mu m$)', fontsize=15)
ax[0].set_ylabel('Flux (Jy)', fontsize=15)
ax[1].set_ylabel('Flux (Jy)', fontsize=15)

ax[0].set_xscale('log')
ax[0].set_yscale('log')
# NOTE(review): with matplotlib's default linthresh (2) every value in the
# +/-0.012 range lies in symlog's linear region, so this axis is effectively
# linear; pass linthresh=... if log-like spacing of the residual is wanted.
ax[1].set_yscale('symlog')

xticks = [5, 7.5, 10, 15, 20, 25]
for panel in ax:
    panel.set_xlim(4.8, 28)
    panel.set_xticks(xticks)
    panel.set_xticklabels(xticks)

residual_ticks = [-1e-2, -5e-3, -1e-3, 1e-3, 5e-3, 1e-2]
ax[1].set_yticks(residual_ticks)
ax[1].set_yticklabels(residual_ticks)

plt.tight_layout()
plt.show()