# Reduce MIRI imaging data with JWST pipeline
import multiprocessing
import os
import shutil
from glob import glob

import numpy as np
from astropy.io import fits

# Set the CRDS path and url
# Modify this path so it points at the CRDS cache on your machine!
# This must be set before importing the pipeline!
os.environ['CRDS_PATH']= '/Users/users/gyang/gyang/crds_cache/jwst_pub'
os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds-pub.stsci.edu'

import jwst
from jwst import pipeline
from jwst.pipeline import Detector1Pipeline
from jwst.pipeline import Image2Pipeline
from jwst.pipeline import Image3Pipeline
from jwst.associations import asn_from_list
from jwst.associations.lib.rules_level3 import Asn_Lv3Image
from jwst import datamodels
from jwst.resample.resample_utils import build_driz_weight

from astropy.table import Table
from astropy.wcs import WCS
from astropy.coordinates.sky_coordinate import SkyCoord
from astropy.stats import sigma_clipped_stats
import astropy.units as u

from scipy.stats import median_abs_deviation as mad

import matplotlib.pyplot as plt

%run -i 'write_rms_map.ipynb'
%run -i 'run_tkreg.ipynb'
%run -i 'super_bg.ipynb'
%run -i 'group_imgs.ipynb'


In [1]:
def find_neg_jump(frame1, frame2, frac_thresh=0.5):
    """Locate a contiguous band of rows where ``frame2`` dropped below ``frame1``.

    A row is flagged when more than ``frac_thresh`` of its pixels satisfy
    ``frame2 < frame1``. The flagged rows must form one contiguous run; the
    returned range is expanded by one row on each side (clipped to the
    frame edges) to be safe.

    Parameters
    ----------
    frame1, frame2 : 2-D array-like of shape (Ny, Nx)
        Two group images; ``frame2`` is compared against ``frame1``.
    frac_thresh : float
        Minimum fraction of pixels in a row that must drop for the row to count.

    Returns
    -------
    tuple
        ``(lo_row_idx, up_row_idx)``: the inclusive, zero-based row range of
        the jump, or ``(None, None)`` when no row passes the threshold.

    Raises
    ------
    Exception
        If the flagged rows are not contiguous.
    """
    frame1 = np.asarray(frame1)
    frame2 = np.asarray(frame2)
    Ny, Nx = frame1.shape
    # Compare directly rather than testing ``frame2 - frame1 < 0``: for
    # unsigned integer input (e.g. raw uint16 ramps) the subtraction wraps
    # around and a negative difference would never be detected.
    neg_fracs = (frame2 < frame1).sum(axis=1) / Nx
    jump_row_idxs = np.where(neg_fracs > frac_thresh)[0]

    if len(jump_row_idxs) == 0:
        # No jump
        return None, None
    if jump_row_idxs[-1] - jump_row_idxs[0] + 1 != len(jump_row_idxs):
        # The jump rows are expected to be adjacent
        print(jump_row_idxs)
        raise Exception('jump rows are not continuous?!')
    # Expand the band by one additional row on each side, staying inside the frame
    lo_row_idx = max(jump_row_idxs[0] - 1, 0)
    up_row_idx = min(jump_row_idxs[-1] + 1, Ny - 1)
    return lo_row_idx, up_row_idx
        
def check_neg_jump(frame1, frame2, lo_row_idx, up_row_idx, frac_thresh=0.5):
    """Return True when rows ``lo_row_idx..up_row_idx`` show NO negative jump.

    The fraction of pixels in the (inclusive) row band with
    ``frame2 < frame1`` is computed. If it exceeds ``frac_thresh``, the band
    satisfies the negative-jump criterion and False is returned; otherwise
    True is returned.
    """
    rows = slice(lo_row_idx, up_row_idx + 1)
    # Direct comparison avoids unsigned-integer wrap-around of frame2 - frame1
    neg_frac = np.mean(np.asarray(frame2)[rows] < np.asarray(frame1)[rows])
    return not (neg_frac > frac_thresh)

def set_dq(dq, m, n, k, i, j):
    """Flag a block of groups/rows of integration ``m`` in a group-DQ array.

    The groups from min(n, k) to max(n, k) (inclusive) and rows ``i..j``
    (inclusive) are updated:

    * if the earlier group index is 0 or 1, the JUMP_DET bit (4) is cleared
      and the DO_NOT_USE bit (1) is set (Jane Morrison's recommendation);
    * otherwise the JUMP_DET bit (4) is set.

    Only pixels that lack the target bit are incremented, so no bit is added
    twice. ``dq`` is modified in place and also returned.
    """
    do_not_use, jump_det = 1, 4
    first, last = sorted((k, n))
    # Basic slicing yields a view, so writes through `block` land in dq
    block = dq[m, first:last + 1, i:j + 1]

    if first > 1:
        needs_flag = (block & jump_det) == 0
        block[needs_flag] += jump_det
        print("  set to JUMP_DET")
    else:
        # Groups at the start of the integration: drop the jump flag and
        # mark them DO_NOT_USE instead
        has_jump = (block & jump_det) != 0
        block[has_jump] -= jump_det
        needs_flag = (block & do_not_use) == 0
        block[needs_flag] += do_not_use
        print("  set to DO_NOT_USE")

    print('  %d pixels corrected' % int(needs_flag.sum()))
    return dq

def validate(dq, m, n, i, j):
    """Return True if (integration m, group n, rows i..j) is in range and usable.

    The segment is rejected when any index is out of bounds, when ``i > j``,
    or when every pixel of the segment already has the DO_NOT_USE bit (1) set.
    """
    n_int, n_grp, n_row = dq.shape[0], dq.shape[1], dq.shape[2]
    in_bounds = 0 <= m < n_int and 0 <= n < n_grp and 0 <= i <= j < n_row
    if not in_bounds:
        return False
    all_unusable = (dq[m, n, i:j + 1] % 2 == 1).all()
    return not all_unusable

def _locate_pos_jump(data, dq, m, n, i, j):
    """Return the group index k paired with the negative jump between groups n-1 and n.

    Only groups n-3 .. n+2 of integration m (rows i..j) are consulted.

    Raises
    ------
    Exception
        If neither neighbouring side has valid groups to test.
    """
    # No net drop from n-2 to n although n-1 -> n drops: the excess sits in group n-1
    if validate(dq, m, n-2, i, j) and check_neg_jump(data[m, n-2], data[m, n], i, j):
        return n - 1
    # No net drop from n-1 to n+1 although n-1 -> n drops: group n is the odd one out
    if validate(dq, m, n+1, i, j) and check_neg_jump(data[m, n-1], data[m, n+1], i, j):
        return n + 1
    # Otherwise look one group further out, on whichever side has valid data
    if validate(dq, m, n-3, i, j):
        return n - 2 if check_neg_jump(data[m, n-3], data[m, n], i, j) else n + 2
    if validate(dq, m, n+2, i, j):
        return n + 2 if check_neg_jump(data[m, n-1], data[m, n+2], i, j) else n - 2
    raise Exception("sorry, can't locate positive jump due to incomplete data")


def correct_jump(datamodel):
    """Flag ramp segments corrupted by paired negative/positive row jumps.

    For every pair of consecutive groups (n-1, n) in every integration, rows
    where the signal drops (see ``find_neg_jump``) are located. The matching
    group ``k`` is then found (see ``_locate_pos_jump``), and groups k..n of
    those rows are flagged in ``groupdq`` via ``set_dq``.

    Parameters
    ----------
    datamodel : ramp data model
        Must provide ``data`` and ``groupdq``. Judging from the unpacking
        below, ``data`` is 4-D (integration, group, row, column).

    Returns
    -------
    The same ``datamodel``, with ``groupdq`` updated in place.
    """
    data = datamodel.data
    dq = datamodel.groupdq
    # Number of integrations, groups and rows
    Ni, Nf, Ny, _ = data.shape
    for m in range(Ni):
        for n in range(1, Nf):
            # Both groups of the pair must be usable over the full frame
            if not (validate(dq, m, n-1, 0, Ny-1) and validate(dq, m, n, 0, Ny-1)):
                continue
            # Rows with a negative jump between groups n-1 and n
            i, j = find_neg_jump(data[m, n-1], data[m, n])
            if i is None or not (validate(dq, m, n-1, i, j) and validate(dq, m, n, i, j)):
                continue
            k = _locate_pos_jump(data, dq, m, n, i, j)
            print('ITG=%d, neg=%d, pos=%d, row=%d-%d' %(m+1, n+1, k+1, i+1, j+1))
            # Flag the bad segments (DO_NOT_USE or JUMP_DET, see set_dq)
            dq = set_dq(dq, m, n, k, i, j)

    datamodel.groupdq = dq
    return datamodel

# Show the jump effect
# Load the project environment settings (IPython magic; only valid inside a notebook)
%run -i '../env_set.ipynb'
# Raw (uncalibrated, "_uncal") ramp cube of one exposure, used only for the demo plot below
raw_data = fits.getdata('../data/raw/jw01345/mastDownload/JWST/jw01345001001_14101_00002_mirimage/jw01345001001_14101_00002_mirimage_uncal.fits')

# Plot the ramps of two pixels in column x=700 (rows y=850 and y=150) for
# one integration, in units of 10^4 DN. A dashed line marks group 4; in the
# lower panel, groups 4-5 are highlighted with open squares as "masked".
itg_idx = 4

plt.figure(figsize=(9, 6))
ramp1, ramp2 = (raw_data[itg_idx, :, row, 700] / 1e4 for row in (850, 150))
group_id = np.arange(1, len(ramp1) + 1)

# Upper panel: pixel (700, 850)
plt.subplot(2, 1, 1)
plt.plot(group_id, ramp1, 'ok')
plt.ylim(0.5, 5.2)
plt.vlines(4, 0, 6, colors='k', linestyles='dashed')
plt.ylabel(r'DN/$10^4$')
plt.text(29, 0.8, 'x=700, y=850', fontsize=15)

# Lower panel: pixel (700, 150), with the masked groups highlighted
plt.subplot(2, 1, 2)
plt.plot(group_id, ramp2, 'ok', label='ramp data')
plt.plot(group_id[3:5], ramp2[3:5], 'sC3',
         markerfacecolor='none', ms=10, label='masked')
plt.legend(loc=4)
plt.vlines(4, 0, 6, colors='k', linestyles='dashed')
plt.ylim(0.5, 5.2)
plt.ylabel(r'DN/$10^4$')
plt.xlabel('group')
plt.text(29, 1.8, 'x=700, y=150', fontsize=15)

output_dir = '../data/jump/'

# Find the raw-data directories of the two jump-affected F2100W exposures
input_dirs = sorted(glob("../data/raw/jw01345/mastDownload/JWST/jw0134500*001_14101*mirimage/"))

## Stage 1
for input_dir in input_dirs:
    raw_files = sorted(glob(input_dir + '*_uncal.fits'))
    for file in raw_files:
        det1 = Detector1Pipeline()
        det1.save_results = False
        # Parallelism: 'none', 'quarter', 'half', 'all', or a fractional number
        det1.ramp_fit.maximum_cores = 'quarter'
        det1.jump.maximum_cores = det1.ramp_fit.maximum_cores
        # Run through the jump step only: the last two steps
        # (ramp_fit and gain_scale) are skipped
        det1.ramp_fit.skip = True
        det1.gain_scale.skip = True
        res0 = det1.run(file)

        # Flag the jump-affected segments (groupdq is updated in place)
        res = correct_jump(res0)

        # Now run ramp fitting alone on the corrected ramps and save its
        # products (gain_scale stays skipped, as set above)
        det1.ramp_fit.skip = False
        det1.ramp_fit.save_results = True
        res = det1.ramp_fit(res)

# Rename the ramp_fit products ("*0_ramp_fit" -> "rate",
# "*1_ramp_fit" -> "rateints") and move them into output_dir. Each suffix is
# handled independently so that a missing counterpart cannot stop the other
# file from being moved (zipping two globs would silently truncate).
os.makedirs(output_dir, exist_ok=True)
for old_suffix, new_suffix in (('0_ramp_fit', 'rate'), ('1_ramp_fit', 'rateints')):
    for fname in glob('*' + old_suffix + '.fits'):
        shutil.move(fname, os.path.join(output_dir, fname.replace(old_suffix, new_suffix)))

# Search for stage-1 result files and group them by filter band
stage1_files = sorted(glob(output_dir + '*_rate.fits'))
flt_groups, flt_vals = group_imgs(stage1_files, 'band')
for files in flt_groups:
    for file in files:
        img2 = Image2Pipeline()
        img2.save_results = True
        img2.run(file)
        # The pipeline writes its products into the working directory;
        # move every FITS file found there into output_dir
        for product in glob('*.fits'):
            shutil.move(product, os.path.join(output_dir, product))

# Subtract background for the stage-2 results, using Casey's algorithm.
# Search for the original stage-2 result files.
stage2_files = sorted(glob(output_dir + '*_cal.fits'))

# Remove strips or global background.
# Use a quarter of the CPUs, but at least one process: Pool(processes=0)
# raises ValueError (e.g. on machines with fewer than 4 CPUs).
Nc = max(1, min(multiprocessing.cpu_count() // 4, len(stage2_files)))
with multiprocessing.Pool(processes=Nc) as pool_obj:
    pool_obj.map(rm_strip, stage2_files)


def _super_bg_task(args):
    """Run ``super_bg(f0, files)`` for one ``(f0, files)`` pair, discarding its result."""
    f0, group_files = args
    super_bg(f0, group_files)


# Remove background
stage2_files = sorted(glob(output_dir + '*stprm.fits'))
# Extract filter info for each file
flt_groups, flt_vals = group_imgs(stage2_files, 'band')
# Subtract the background group by group
for flt, files in zip(flt_vals, flt_groups):
    print(flt + ':')
    Nc = max(1, min(multiprocessing.cpu_count() // 4, len(files)))
    # Pass the group explicitly instead of having a worker read the global `files`
    with multiprocessing.Pool(processes=Nc) as pool_obj:
        pool_obj.map(_super_bg_task, [(f0, files) for f0 in files])

## Stage 3

# Select which background-subtracted stage-2 files to combine, and the
# prefix for the output mosaics. Alternatives used elsewhere:
#   pattern '*14101*_stprm_bgsub.fits' with name 'bad_corrected_'
#   pattern '*12101*_stprm_bgsub.fits' with name 'good_'
stage2_files = sorted(glob(os.path.join(output_dir, '*14101*_stprm_bgsub.fits')))
name = 'bad_uncorrected_'

# Group the files by field and filter band
ff_groups, ff_vals = group_imgs(stage2_files, 'field_band')

# Define the function to parallel
def do_ff(ff_idx):
    group = ff_groups[ff_idx]
    
    # Create an association 
    asn = dict( asn_from_list.asn_from_list(group, rule=Asn_Lv3Image, 
                product_name='l3_results', asn_type="image3") )
    # Initialize stage 3 pipeline
    # In the following, we use ".copy()" as input to avoid the input being changed 
    img3 = Image3Pipeline()
    # Load the data models
    dms = datamodels.open(asn)
 
    cr_sigclip=1
    tkreg_res = run_tkreg(dms, img3, cr_sigclip=cr_sigclip, run2=True)

    # match bkg
    img3.skymatch.subtract = True
    img3.skymatch.skymethod = 'local'
    res = img3.skymatch.run(tkreg_res.copy())

    # Reject cosmic-ray 
    res = img3.outlier_detection.run(res)
    
    # First resample to get the coordinate transformation
    img3.resample.rotation = -49.7
    img3.resample.pixel_scale = 0.09
    img3.resample.output_shape = 1500, 1500
    temp_res = img3.resample(res.copy())    
    
    # Calculate the trangential point coordinates
    tg_ra, tg_dec = 214.825, 52.825
    tg_x, tg_y = temp_res.get_fits_wcs().world_to_pixel( SkyCoord(tg_ra*u.degree, tg_dec*u.degree, frame='icrs') )
    # Round to 0.5 to align with HST 
    tg_x, tg_y = np.round(tg_x)+60.5, np.round(tg_y)+60.5
    # Resample again to get the TG point correct
    img3.resample.crval = tg_ra, tg_dec
    img3.resample.crpix = tg_x, tg_y
    res = img3.resample(res)
    # Output 
    output_fname = output_dir + name + ff_vals[ff_idx].lower() + '_i2d.fits'
    res.save(output_fname)

# Decide the number of worker processes: a quarter of the CPUs, but at least
# one, because Pool(processes=0) raises ValueError
Nc = max(1, min(multiprocessing.cpu_count() // 4, len(ff_vals)))
with multiprocessing.Pool(processes=Nc) as pool_obj:
    pool_obj.map(do_ff, range(len(ff_vals)))

# Clean up temporary FITS files left in the working directory by intermediate steps
for tmp_fname in glob('*.fits'):
    os.remove(tmp_fname)

# Produce the RMS map for every stage-3 mosaic and keep the RMS scaling factors
stage3_files = sorted(glob(output_dir + '*i2d.fits'))
rms_facs = [write_rms_map(fname, input_type='wht') for fname in stage3_files]

from astropy.io import fits
from glob import glob
from astropy.table import Table
from astropy.wcs import WCS
from astropy.wcs.utils import proj_plane_pixel_scales
%run -i '../../tools/astrometry/cut_img.ipynb'

# Re-write the science images as single-HDU FITS files for TPHOT


def _pc_to_cd(header):
    """Replace the PCi_j + CDELTi WCS keywords in ``header`` by a CDi_j matrix, in place.

    Per the FITS WCS convention, CDi_j = CDELTi * PCi_j: row i is scaled by
    CDELTi, so CD2_* must use CDELT2, not CDELT1. The PC and CDELT keywords
    are removed afterwards.
    """
    for i in (1, 2):
        cdelt = header['CDELT%d' % i]
        for j in (1, 2):
            header['CD%d_%d' % (i, j)] = header['PC%d_%d' % (i, j)] * cdelt
    for key in ('PC1_1', 'PC1_2', 'PC2_1', 'PC2_2', 'CDELT1', 'CDELT2'):
        header.remove(key)


# Collect the science mosaics. `glob` here is the function imported via
# "from glob import glob", so it is called directly (glob.glob would fail).
sci_files = glob('../data/jump/*f2100w*i2d.fits')
sci_files.extend(glob('../data/jump/uncorrected/*f2100w*i2d.fits'))

for sci_file in sci_files:
    with fits.open(sci_file) as hdu:
        sci_hdu = hdu[1]
        _pc_to_cd(sci_hdu.header)

        # Dump the science extension's data and header to a new image file.
        # Note: "sci_hdu.write" is not used because it would place the image
        # in a FITS extension, which is incompatible with TPHOT.
        out_fname = os.path.join('../data/tphot/images/',
                                 os.path.basename(sci_file).replace('i2d', 'sci'))
        fits.writeto(out_fname, sci_hdu.data, header=sci_hdu.header, overwrite=True)

input_dirs = sorted(glob("../data/raw/jw01345/mastDownload/JWST/jw01345001001_14101_00003_mirimage/"))

## Stage 1
for input_dir in input_dirs:
    raw_files = sorted(glob(input_dir + '*_uncal.fits'))
    for file in raw_files:
        det1 = Detector1Pipeline()
        det1.save_results = False
        # Parallelism: 'none', 'quarter', 'half', 'all', or a fractional number
        det1.ramp_fit.maximum_cores = 'quarter'
        det1.jump.maximum_cores = det1.ramp_fit.maximum_cores
        # Run through the jump step only: the last two steps
        # (ramp_fit and gain_scale) are skipped
        det1.ramp_fit.skip = True
        det1.gain_scale.skip = True
        # Also save the jump-step product (before the correction below)
        det1.jump.save_results = True
        res0 = det1.run(file)

        # Correct a copy so res0 keeps the uncorrected ramps, and save it.
        # Done per file so every raw file is corrected, not only the last one.
        res = correct_jump(res0.copy())
        corrected_name = os.path.basename(file).replace('_uncal.fits', '_jump_corrected.fits')
        res.save(os.path.join('../data/jump/', corrected_name))

        # Now run ramp fitting alone on the corrected ramps and save its
        # products (gain_scale stays skipped, as set above)
        det1.ramp_fit.skip = False
        det1.ramp_fit.save_results = True
        res = det1.ramp_fit(res)

# Rename the ramp_fit products ("*0_ramp_fit" -> "rate",
# "*1_ramp_fit" -> "rateints") and move them into output_dir, handling each
# suffix independently so no file is dropped when the counts differ.
for old_suffix, new_suffix in (('0_ramp_fit', 'rate'), ('1_ramp_fit', 'rateints')):
    for fname in glob('*' + old_suffix + '.fits'):
        shutil.move(fname, os.path.join(output_dir, fname.replace(old_suffix, new_suffix)))