In [2]:
import sep
import multiprocessing

import pyregion
from astropy.convolution import Tophat2DKernel
from scipy.ndimage import binary_dilation
from astroscrappy import detect_cosmics
from astropy.table import Table
import numpy as np
from astropy.io import fits

In [3]:
# The super-bkg method
def rm_strip(cal_file, rm_hz=False, rm_vt=False):
    """Subtract striping (or a constant background) from an image and save it.

    The file is opened with ``datamodels.open``.  Pixels whose error is not
    finite, or whose DQ value is greater than 4, are treated as bad.

    Parameters
    ----------
    cal_file : str
        Path of the input file.  The result is written next to it with
        ``.fits`` replaced by ``_stprm.fits``.
    rm_hz : bool
        If set, subtract a per-row level (clipped statistic along axis 1).
    rm_vt : bool
        If set, subtract a per-column level (clipped statistic along axis 0).
        When both flags are set the row level is removed first.

    Returns
    -------
    float
        The global background from ``get_glob_bkg``.  It is only subtracted
        from the image when neither ``rm_hz`` nor ``rm_vt`` is requested.
    """
    # Options shared by the two clipped-statistics calls.
    # NOTE(review): element [1] of the result is presumably the median, as in
    # astropy.stats.sigma_clipped_stats -- confirm which function is in scope.
    clip_opts = dict(sigma=2, cenfunc='median', stdfunc='std')

    with datamodels.open(cal_file) as model:
        # Work on copies so the model is only modified on purpose below.
        image = model.data.copy()
        noise = model.err.copy()

        # Flag bad pixels and give them infinite uncertainty.
        bad = ~np.isfinite(noise) | (model.dq > 4)
        noise[bad] = np.inf

        # The global level is always computed because it is the return value.
        glob_bkg = get_glob_bkg(image, noise)

        # Row-wise (horizontal) stripes.
        if rm_hz:
            row_level = sigma_clipped_stats(image, mask=bad, axis=1, **clip_opts)[1]
            image = (image.T - row_level).T

        # Column-wise (vertical) stripes.
        if rm_vt:
            col_level = sigma_clipped_stats(image, mask=bad, axis=0, **clip_opts)[1]
            image -= col_level

        if (rm_hz or rm_vt) != False:
            # Stripe removal may have produced NaNs: flag them, zero them,
            # and mark them as infinitely uncertain.
            nan_pix = np.isnan(image)
            bad[nan_pix] = True
            image[nan_pix] = 0
            noise[bad] = np.inf
        else:
            # No stripe removal requested: subtract the constant level only.
            image -= glob_bkg

        # Only the science array is replaced; the error array is not saved.
        model.data = image
        model.save(cal_file.replace('.fits', '') + '_stprm.fits', overwrite=True)

    return glob_bkg
        
def get_glob_bkg(dat, err, col_min=300, col_max=1026):
    """Return the median of the image inside a window of columns.

    Pixels with a non-finite error are ignored, as are NaN data values.

    Parameters
    ----------
    dat : numpy.ndarray
        2-D image.
    err : numpy.ndarray
        Error array with the same shape as ``dat``.  It is not modified.
    col_min : int or None, optional
        Columns with index below this value are excluded.  ``None`` disables
        the lower bound.  Defaults to 300, the previously hard-coded value.
    col_max : int or None, optional
        Columns with index at or above this value are excluded.  ``None``
        disables the upper bound.  Defaults to 1026, the previously
        hard-coded value.

    Returns
    -------
    float
        ``np.nanmedian`` of the selected pixels (NaN if none are selected).
    """
    # np.isfinite allocates a new array, so the caller's ``err`` is untouched.
    usable = np.isfinite(err)
    if col_min is not None:
        usable[:, :col_min] = False
    if col_max is not None:
        usable[:, col_max:] = False
    # Select first, then cast: only the pixels that are used get converted.
    return np.nanmedian(dat[usable].astype('float'))

def det_src(img, sigma_threshold=2, bw=72, bh=72, fw=5, fh=5,
            clean_edge=True, cr_sigclip=1., minarea=5, r_dilate=-1,
            kernel=np.array([[1,2,1], [2,4,2], [1,2,1]])):
    """Detect sources in an image and build a segmentation map with SEP.

    Parameters
    ----------
    img : str or object with ``data``, ``err`` and ``dq`` attributes
        Either the path of a FITS file with 'sci', 'err' and 'dq' extensions,
        or an already opened data model exposing those three arrays.
    sigma_threshold : float
        Detection threshold passed to ``sep.extract``.
    bw, bh, fw, fh : int
        Box and filter sizes passed to ``sep.Background``.
    clean_edge : bool
        If set, drop every source whose centre falls inside one of the
        regions listed in ``../data/master_noisy_region.reg`` and clear its
        pixels in the segmentation map.
    cr_sigclip : float
        ``sigclip`` passed to ``detect_cosmics``.
    minarea : int
        Minimum source area passed to ``sep.extract``.
    r_dilate : float
        If positive, the segmentation map is dilated with a top-hat footprint
        of this radius.  The dilated map is binary (0/1): source ids are lost.
    kernel : numpy.ndarray
        Filter kernel passed to ``sep.extract``.

    Returns
    -------
    src : astropy.table.Table
        Source catalogue with an extra 1-based ``id`` column matching the
        labels of the (undilated) segmentation map.
    dat : numpy.ndarray
        The science array that was analysed.
    wht : numpy.ndarray
        Inverse-variance weights normalised by their median; NaN where masked.
    bkg : sep.Background
        Background model fitted to the science array.
    seg : numpy.ndarray
        Segmentation map.

    Raises
    ------
    TypeError
        If ``img`` is neither a path string nor an object providing
        ``data``, ``err`` and ``dq``.
    """
    # Load the three arrays.  isinstance also accepts str subclasses, and the
    # attribute check avoids depending on one specific data-model class.
    if isinstance(img, str):
        dat = fits.getdata(img, 'sci')
        err = fits.getdata(img, 'err')
        dq = fits.getdata(img, 'dq')
    elif all(hasattr(img, attr) for attr in ('data', 'err', 'dq')):
        dat = img.data.copy()
        err = img.err.copy()
        dq = img.dq.copy()
    else:
        raise TypeError(
            "img must be a FITS file path or an object with 'data', 'err' "
            "and 'dq' attributes, got %s" % type(img).__name__)

    # Give cosmic-ray hits and flagged pixels (DQ > 4) infinite uncertainty.
    cr_map = np.asarray(detect_cosmics(dat, sigclip=cr_sigclip)[0], dtype=bool)
    err[cr_map] = np.inf
    err[dq > 4] = np.inf

    # Inverse-variance weights, normalised by their median.
    wht = 1. / err**2
    mask = ~np.isfinite(err)
    wht[mask] = np.nan
    wht /= np.nanmedian(wht)

    # Spatially varying background.  back() returns a new array on each call,
    # so evaluate it once; masked pixels are set to NaN in wdat below.
    bkg = sep.Background(dat.astype('float'), mask=mask, bw=bw, bh=bh, fw=fw, fh=fh)
    back = bkg.back()

    # Detect on the background-subtracted, weighted image.
    wdat = (dat - back) * wht
    wdat[mask] = np.nan
    wbkg = sep.Background(wdat, mask=mask, bw=bw, bh=bh, fw=fw, fh=fh)
    src, seg = sep.extract(wdat, sigma_threshold, mask=mask, filter_kernel=kernel,
                           err=wbkg.globalrms, segmentation_map=True,
                           minarea=minarea)

    # Catalogue as an astropy table with a 1-based id matching the seg labels.
    src = Table(src)
    src.add_column(np.arange(len(src)) + 1, name='id')

    # Remove sources located in the known noisy regions, if requested.
    if clean_edge:
        regs = pyregion.open("../data/master_noisy_region.reg")
        xs = np.asarray(src['x'])
        ys = np.asarray(src['y'])
        in_bad = np.zeros(len(src), dtype=bool)
        for reg in regs:
            # coord_list is read as (x centre, y centre, width, height, ...).
            xc, yc, width, height = reg.coord_list[:4]
            in_bad |= ((xs > xc - width / 2) & (xs < xc + width / 2) &
                       (ys > yc - height / 2) & (ys < yc + height / 2))
        # Clear all rejected labels in one pass, then drop the catalogue rows.
        bad_ids = np.flatnonzero(in_bad) + 1
        seg[np.isin(seg, bad_ids)] = 0
        src = src[np.flatnonzero(~in_bad)]

    # Optional dilation; the result is a 0/1 map.
    if r_dilate > 0:
        footprint = Tophat2DKernel(radius=r_dilate)
        seg = binary_dilation(seg, footprint.array).astype(int)

    return src, dat, wht, bkg, seg

def super_bg(files, sigma_threshold=2, bw=72, bh=72, fw=5, fh=5, minarea=20,
             kernel=np.array([[1,2,1], [2,4,2], [1,2,1]]), r_dilate=10,
             clean_edge=True, write_bg=False):
    """Subtract a leave-one-out median background from every file.

    Sources are detected in each file with ``det_src`` and masked.  For each
    file the background is the pixel-wise NaN-median of all the *other*
    source-masked images; it is subtracted from the 'sci' extension and the
    result is written to ``<name>_bgsub.fits``.

    NOTE(review): another function named ``super_bg`` with a different
    signature is defined later in this file and replaces this one when the
    file is run top to bottom.

    Parameters
    ----------
    files : sequence of str
        Paths of the input FITS files (at least two).
    sigma_threshold, bw, bh, fw, fh, minarea, kernel, r_dilate, clean_edge
        Passed to ``det_src``.  With exactly two files, ``clean_edge``,
        ``r_dilate``, ``minarea`` and ``sigma_threshold`` are overridden with
        (False, -1, 5, 1.2) and sources are filled with the SEP background
        instead of NaN, since only one reference image is available.
    write_bg : bool
        If set, also write the background itself to ``<name>_bg.fits``.

    Raises
    ------
    ValueError
        If fewer than two files are given.
    """
    # Validate before touching any file so that empty input gives a clear
    # error rather than an IndexError.
    n_files = len(files)
    if n_files < 2:
        raise ValueError("Too few input files?!")

    single_ref = (n_files == 2)
    if single_ref:
        print('Warning: building master background from a single file')
        # Use different parameters if building from a single image
        clean_edge = False
        r_dilate = -1
        minarea = 5
        sigma_threshold = 1.2
    else:
        print('Building master background from a %d files' %(len(files)-1))

    # Source-masked copy of every image.
    darr = None
    for i, f in enumerate(files):
        src, dat, wht, bkg, seg = det_src(f, sigma_threshold=sigma_threshold,
            r_dilate=r_dilate, bw=bw, bh=bh, fw=fw, fh=fh, cr_sigclip=1.,
            minarea=minarea, kernel=kernel, clean_edge=clean_edge)
        print("%d src detected!" %len(src))
        on_src = seg > 0
        if single_ref:
            # A median of one image cannot skip NaNs: fill with the model.
            dat[on_src] = bkg.back()[on_src]
        else:
            dat[on_src] = np.nan
        if darr is None:
            # Allocate from the analysed array so the shapes always agree.
            darr = np.zeros((n_files,) + dat.shape)
        darr[i] = dat

    # Leave-one-out background for each file.
    all_idx = np.arange(n_files)
    for idx0, f0 in enumerate(files):
        dmed = np.nanmedian(darr[all_idx != idx0], axis=0)
        # Pixels masked in every reference image have no estimate: use 0.
        dmed[np.isnan(dmed)] = 0
        out_root = f0.replace('.fits','')
        with fits.open(f0) as hdu:
            hdu['sci'].data -= dmed
            # '|' binds tighter than '>', so the DQ test needs parentheses.
            mask = (~np.isfinite(hdu['err'].data)
                    | np.isnan(hdu['sci'].data)
                    | (hdu['dq'].data > 4))
            hdu['sci'].data[mask] = 0
            hdu['err'].data[mask] = np.inf
            hdu.writeto(out_root + '_bgsub.fits', overwrite=True)
            # Write background if needed
            if write_bg:
                hdu['sci'].data = dmed
                hdu.writeto(out_root + '_bg.fits', overwrite=True)
        print(out_root + '_bgsub.fits' + "  Done!")

    return

%run -i './group_imgs.ipynb'
from glob import glob
import os
output_dir = "../data/reduced/"

stage1_files = sorted(glob(output_dir+'*_rate.fits'))

flt_groups, flt_vals = group_imgs(stage1_files, 'field_band')
flt_groups[ flt_vals.index('MIRI1_F770W') ]

# Make the plot for the paper
f0 = 'jw01345001001_06101_00001_mirimage_cal.fits'

# Mask and subtract a median background for the original cal image
with fits.open('../data/reduced/' + f0) as hdu:
    # Subtract the background
    mask = (np.isfinite(hdu['err'].data)==False) | np.isnan(hdu['sci'].data) | hdu['dq'].data>4
    hdu['sci'].data[mask] = 0
    hdu['sci'].data[~mask] -= np.nanmedian( hdu['sci'].data[~mask] )
    hdu['err'].data[mask] = np.inf
    # Write to file
    f1 = f0.replace('.fits','')+'_forPlot.fits'
    hdu.writeto('../data/reduced/for_plot/'+f1, overwrite=True)
    
# Mask for the strmp image
f0 = f0.replace('.fits','')+'_stprm.fits'
# Mask and subtract a median background for the original cal image
with fits.open('../data/reduced/' + f0) as hdu:
    # Subtract the background
    mask = (np.isfinite(hdu['err'].data)==False) | np.isnan(hdu['sci'].data) | hdu['dq'].data>4
    hdu['sci'].data[mask] = 0
    hdu['err'].data[mask] = np.inf
    # Write to file
    f1 = f0.replace('.fits','')+'_forPlot.fits'
    hdu.writeto('../data/reduced/for_plot/'+f1, overwrite=True)
    
# Do nothing but directly copy the bgsub image
f0 = f0.replace('.fits','')+'_bgsub.fits'
with fits.open('../data/reduced/' + f0) as hdu:
    # Write to file
    f1 = f0.replace('.fits','')+'_forPlot.fits'
    hdu.writeto('../data/reduced/for_plot/'+f1, overwrite=True)

def super_bg(f0, files, sigma_threshold=2, bw=72, bh=72, fw=5, fh=5, minarea=20,
             kernel=np.array([[1,2,1], [2,4,2], [1,2,1]]), r_dilate=10):
    """Subtract from ``f0`` a median background built from the other files.

    Sources are detected with ``det_src`` in every file of ``files`` other
    than ``f0`` and set to NaN.  The pixel-wise NaN-median of those images is
    subtracted from the 'sci' extension of ``f0`` and the result is written to
    ``<f0 without .fits>_bgsub.fits``.

    Parameters
    ----------
    f0 : str
        Path of the FITS file to correct.
    files : sequence of str
        Candidate reference files; entries equal to ``f0`` are skipped.
    sigma_threshold, bw, bh, fw, fh, minarea, kernel, r_dilate
        Passed to ``det_src``.

    Raises
    ------
    ValueError
        If ``files`` contains no file other than ``f0``.
    """
    # Reference images: everything except the target itself.
    otherfiles = [f for f in files if f != f0]
    if not otherfiles:
        # Without this check the median below would fail on an undefined array.
        raise ValueError("No reference files other than f0 to build the background")

    darr = None
    for i, f in enumerate(otherfiles):
        # Detect sources in the reference image
        src, dat, wht, bkg, seg = det_src(f, sigma_threshold=sigma_threshold, r_dilate=r_dilate,
                bw=bw, bh=bh, fw=fw, fh=fh, cr_sigclip=1., minarea=minarea, kernel=kernel)
        print("%d src detected!" %len(src))
        # Mask the sources so they are ignored by the NaN-median
        dat[seg > 0] = np.nan
        if darr is None:
            # Allocate from the analysed array so the shapes always agree.
            darr = np.zeros((len(otherfiles),) + dat.shape)
        darr[i] = dat

    # Median background; pixels masked in every reference image get 0.
    dmed = np.nanmedian(darr, axis=0)
    dmed[np.isnan( dmed )] = 0

    # Subtract and write out
    out_name = f0.replace('.fits','')+'_bgsub.fits'
    with fits.open(f0) as hdu:
        hdu['sci'].data -= dmed
        # '|' binds tighter than '>', so the DQ test needs parentheses.
        mask = (~np.isfinite(hdu['err'].data)
                | np.isnan(hdu['sci'].data)
                | (hdu['dq'].data > 4))
        hdu['sci'].data[mask] = 0
        hdu['err'].data[mask] = np.inf
        hdu.writeto(out_name, overwrite=True)
    print(out_name + "  Done!")

    return