In [None]:
%pylab inline
import os
import miri_lrs_fm
import jwst
import webbpsf

## Find input files

In [None]:
# List the contents of the 03762/ data directory to see which input files are available
!ls 03762/

In [None]:
# Target acquisition (TA) image and TA confirmation image, both "_cal" (calibrated) products
fn_ta = '03762/jw03762002001_02101_00001_mirimage_cal.fits'
fn_taconfirm = '03762/jw03762002001_03102_00001_mirimage_cal.fits'

For the dispersed images, we don't want to use the MAST versions of the CAL files.

Instead, we need re-reductions of these that do NOT include the dither subtraction. This can be done simply by
running the spec2 pipeline manually, like:

```
spec2 = jwst.pipeline.calwebb_spec2.Spec2Pipeline()
spec2.call(filename, save_results=True)
```

When originally running this, I saved these re-reductions in a subdirectory, `no_dither_sub`, which the cells below read from.


In [None]:
# Un-cleaned re-reductions (no dither subtraction). Superseded by the bad-pixel-cleaned
# "_bpclean" versions assigned in the next cell; kept for reference only.
# fn_sci_dith1 = 'no_dither_sub/jw03762002001_03103_00001_mirimage_cal.fits'
# fn_sci_dith2 = 'no_dither_sub/jw03762002001_03103_00002_mirimage_cal.fits'

Furthermore, we can improve on that using a custom bad-pixel finding-and-cleaning routine I wrote as part of this work. 
This cleans up the outliers before modeling. This is not essential, but it is helpful. 

The cell below runs that routine and creates the cleaned files if they don't already exist.

In [None]:
# We also want to run a bad pixel cleaning routine to clean up the outliers before modeling.

fn_sci_dith1 = 'no_dither_sub/jw03762002001_03103_00001_mirimage_bpclean.fits'
fn_sci_dith2 = 'no_dither_sub/jw03762002001_03103_00002_mirimage_bpclean.fits'

# Only needed the first time this notebook runs; afterwards the cleaned files on disk are reused.
# Check BOTH outputs, so an interrupted earlier run (e.g. only dither 1 written) is redone
# instead of failing later when the missing file is opened.
# NOTE(review): assumes find_and_replace_outlier_pixels writes "<name>_bpclean.fits" into
# save_path, matching the two filenames above — confirm against miri_lrs_fm.
if not all(os.path.exists(path) for path in (fn_sci_dith1, fn_sci_dith2)):
    print("Running bad pixel cleaning routine")

    for exp in [1, 2]:
        fn = f'no_dither_sub/jw03762002001_03103_{exp:05d}_mirimage_cal.fits'
        # Context manager closes the input FITS file once the cleaned copy has been saved
        with jwst.datamodels.open(fn) as model:
            miri_lrs_fm.find_and_replace_outlier_pixels(model, save_path='no_dither_sub',
                                                        nsigma=20, median_size=5)  # parameters slightly tuned here to optimize on these data

In [None]:
# Load those files into jwst datamodels objects.
# These stay open for the rest of the notebook, since later cells use them throughout.
model_ta = jwst.datamodels.open(fn_ta)
model_taconfirm = jwst.datamodels.open(fn_taconfirm)
model_sci_dith1 = jwst.datamodels.open(fn_sci_dith1)
model_sci_dith2 = jwst.datamodels.open(fn_sci_dith2)


# We can get the host star coords at the time of observations (include proper motion etc)
# from the header metadata of the TA image:
host_star_coords = miri_lrs_fm.get_target_coords(model_ta)  

## Measure the WCS offset in the target acquisition (TA) image

Note: the observatory's actual pointing was corrected onboard using the target acq image, but the WCS headers continue to be derived from the same guide star, so there is effectively a constant offset in the WCS throughout the whole visit. We measure that offset here:

In [None]:
# Fit the star position in the TA image and derive the WCS offset (saveplot=True presumably
# also writes the diagnostic plot to disk — confirm). Only `wcsoffset` is used later in this
# notebook; `res` and `cov` are presumably the fit result and its covariance.
res, cov, wcsoffset = miri_lrs_fm.ta_position_fit_plot(model_ta, saveplot=True) 

In [None]:
# Sanity-check the TA confirmation image against the host star coordinates,
# using the WCS offset measured from the TA image above.
miri_lrs_fm.plot_ta_verification_image(
    model_taconfirm,
    wcs_offset=wcsoffset,
    box_size=80,
    host_star_coords=host_star_coords,
)

## Set up the webbpsf simulation to match that observation

In [None]:
# Presumably a webbpsf MIRI instrument object configured to match the TA confirmation exposure;
# it is passed to every PSF-generation call below.
miri = miri_lrs_fm.setup_sim_to_match_file(fn_taconfirm)

### Refine the WCS offset used, to better match the simulation to the data

I did this part iteratively, re-running the notebook multiple times.

You can set an X, Y offset that's applied in addition to the WCS offset derived above. 

In [None]:
# Value copied from the YSES 1 analyses. NOTE: superseded — `tweak_offset` is reassigned
# in the following cells, so this value is not used on a top-to-bottom run.
tweak_offset = (-0.227, -0.50)  # Derived from running the below without a tweak offset, and seeing what the
                               # residual is between the WCS offset coords and the Gaussian fit center coords. 

In [None]:
# First-iteration setting (no extra tweak). Overridden by the next cell on a full run;
# run only this one when re-deriving the tweak offset from scratch.
tweak_offset=None

In [None]:
# Tweak offset actually used for this target (overrides the two preceding cells).
tweak_offset = (-0.2, -0.06)  # Derived from a first iteration of this code, using offset = None
                              # then using the registration function below

In [None]:
# Display the tweak offset in effect for the rest of the notebook
tweak_offset

## Generate test PSF model for the TA Confirm observation

In [None]:
# Compare a simulated PSF against the TA confirmation image,
# positioned using the WCS offset plus the extra tweak offset.
miri_lrs_fm.plot_taconfirm_psf_comparison(
    model_taconfirm, miri, host_star_coords, wcsoffset,
    tweak_offset=tweak_offset,
    vmax=100,
)

## Check the dither observation distance

I *think* this should be the same in all cases, within the observatory's dither precision of a few milliarcseconds.
This seems to be the case for at least the handful of MIRI LRS observations checked so far. 

Side question: empirically, the dither move is close to, but not exactly, (17, 0) pixels. Is that intentional?

In [None]:
# Measure the offset between the two dither positions; used below to derive the dither-2 tweak offset.
dither_offset = miri_lrs_fm.measure_dither_offset(model_sci_dith1, model_sci_dith2, plot=True, saveplot=True)

## Generate the PSF datacube over wavelengths

In [None]:
# Simulate the monochromatic PSFs for dither 1 over wavelength (dispersed further below).
# This is slow the first time, but will save the cube to disk for subsequent reuse on later calls
psfs_cube, y_samp, wave_samp, converters = miri_lrs_fm.generate_lrs_psf_cube(model_taconfirm, model_sci_dith1, miri,
                                                                             host_star_coords, wcsoffset,
                                                                             tweak_offset=tweak_offset,
                                                                             # presumably a small value (e.g. nlambda=20) gives a faster, coarser test run
                                                                             nlambda=None # do all wavelengths
                                                                             )

In [None]:
# Dither 2 uses the dither-1 tweak offset, shifted by the measured dither offset.
# Default to (0, 0) in a separate variable rather than reassigning the global `tweak_offset`:
# mutating it here would silently change what the earlier dither-1 / TA-confirm cells use
# if they are re-run after this one.
tweak_offset_dith1 = (0, 0) if tweak_offset is None else tweak_offset
tweak_offset_dith2 = list(np.asarray(tweak_offset_dith1) - dither_offset)  # Note the sign needs to be negative here!


# This is slow the first time, but will save the cube to disk for subsequent reuse on later calls
psfs_cube_d2, y_samp_d2, wave_samp_d2, converters_d2 = miri_lrs_fm.generate_lrs_psf_cube(
    model_taconfirm, model_sci_dith2, miri,
    host_star_coords, wcsoffset,
    tweak_offset=tweak_offset_dith2,
)

### Disperse the PSF

This uses the wavecal to disperse the monochromatic sims. 

In [None]:
# Disperse the dither-1 monochromatic PSF cube into a 2D spectral model.
# NOTE(review): powerlaw=2 presumably sets the assumed source spectral slope — confirm in miri_lrs_fm.
dispersed_model_d1 = miri_lrs_fm.generate_dispersed_lrs_model(
    psfs_cube, miri, wave_samp, converters,
    powerlaw=2,
    add_cruciform=True,
)

In [None]:
# Same dispersion step for dither 2, using the dither-2 PSF cube and converters.
dispersed_model_d2 = miri_lrs_fm.generate_dispersed_lrs_model(
    psfs_cube_d2, miri, wave_samp_d2, converters_d2,
    powerlaw=2,
    add_cruciform=True,
)

### Estimate the background

We use the estimated background to (approximately) remove it from the data prior to fitting the scale factor between the model PSF and the data. A rough approximation appears to be sufficient for that.

In [None]:
# Background estimated from dither 1 only; the same `bg` is reused for both dithers below.
bg, axes = miri_lrs_fm.estimate_background_spectrum(model_sci_dith1, miri, plot=True)

### Measure and refine the alignment

This is where the `tweak_offset` value used above came from: I ran this with the initial version of the PSFs
and no `tweak_offset`, measured the offset, copied that value above, deleted the pre-computed PSFs so they would be recomputed the next time, and re-ran the notebook. 

In [None]:
# Measure the residual misalignment between the dither-2 data and its dispersed model;
# the result is what was copied into `tweak_offset` above.
miri_lrs_fm.image_registration_dispersed_model(model_sci_dith2, dispersed_model_d2, bg) 

In [None]:
# In this case we can also try this on dither 1,
# but it doesn't work super well since there's not much SNR on the speckles.
miri_lrs_fm.image_registration_dispersed_model(model_sci_dith1, dispersed_model_d1, bg) 

## Subtract the model from the data

In [None]:
# Fit the scale of the dither-1 model to the background-subtracted data, then subtract it.
sub_dith1 = miri_lrs_fm.scale_and_subtract_dispersed_model(
    model_sci_dith1, dispersed_model_d1, bg, converters,
    vmax=1e3,
)

In [None]:
# Same scale-and-subtract step for dither 2 (note: still uses the dither-1 `converters`).
sub_dith2 = miri_lrs_fm.scale_and_subtract_dispersed_model(
    model_sci_dith2, dispersed_model_d2, bg, converters,
    vmax=1000,
)

In [None]:
# NOTE(review): this overwrites the sub_dith1/sub_dith2 results of the two cells above with
# the *_starsub.fits files in the current directory. Those files are presumably written by
# scale_and_subtract_dispersed_model — confirm. Running this cell alone lets the comparison
# below run without recomputing, but it fails if those files don't exist yet.
sub_dith1 = jwst.datamodels.open('./jw03762002001_03103_00001_mirimage_starsub.fits')
sub_dith2 = jwst.datamodels.open('./jw03762002001_03103_00002_mirimage_starsub.fits')

In [None]:
# Compare both dithers before and after subtraction of the dispersed star model.
miri_lrs_fm.display_dither_comparisons(
    model_sci_dith1, model_sci_dith2,
    sub_dith1, sub_dith2,
    converters,
)

In [None]:
# NOTE(review): `open` is a macOS command (opens this directory in Finder); it fails on
# Linux/Windows. Convenience only — nothing in the analysis depends on it.
!open .