# Simulate NIRCam Images

## Using MIRAGE and the JWST Pipeline

https://github.com/spacetelescope/mirage/blob/master/examples/Imaging_simulator_use_examples.ipynb

Here we use MIRAGE to simulate NIRCam imaging based on HST observations of the galaxy cluster MACS0647+70.  
Our full APT program executes 160 exposures = 4 dithers x 4 filters x 10 detectors.
Here we simulate images in one filter in one NIRCam module (Module A).
For the short wavelength filters, that will be 4 dithers x 4 detectors = 16 images.
For the long wavelength filters,  that will be 4 dithers x 1 detector  =  4 images.

Then we run the JWST pipeline to combine all 16 or 4 exposures into one image.
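
These counts follow from simple multiplication. A minimal sketch of the arithmetic, assuming each module has 4 short wavelength detectors and 1 long wavelength detector (the variable names are illustrative only):

```python
# Exposure bookkeeping for the APT program described above
n_dithers = 4
n_filters = 4
n_detectors = 10  # both modules: 2 x (4 SW + 1 LW)

total_exposures = n_dithers * n_filters * n_detectors

# One filter in one module
sw_detectors_per_module = 4
lw_detectors_per_module = 1
n_sw_images = n_dithers * sw_detectors_per_module
n_lw_images = n_dithers * lw_detectors_per_module

print(total_exposures, n_sw_images, n_lw_images)
```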

Inputs:

- APT file outputs: `.xml`, `.pointing`
- Galaxy catalog including RA, Dec, Sersic fit parameters, and magnitude

Outputs:

- MIRAGE simulated NIRCam FITS images (raw `_uncal.fits`; `_linear.fits` is not saved)
- JWST Pipeline reduced NIRCam FITS image and catalog (`_i2d.fits`, `_cat.ecsv`)

The JWST Pipeline run includes recommendations from the CEERS program DR1.

Their notebook `ceers_nircam_reduction.ipynb` offers detailed help and instructions:

https://ceers.github.io/releases.html

https://jwst-docs.stsci.edu/jwst-data-reduction-pipeline

https://jwst-docs.stsci.edu/jwst-data-reduction-pipeline/algorithm-documentation/stages-of-processing

https://jwst-pipeline.readthedocs.io/en/latest/jwst/introduction.html

https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/

https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/calwebb_image3.html

"In order to process and combine multiple images, an ASN file must be used as input, listing the exposures to be processed."
    
https://jwst-pipeline.readthedocs.io/en/latest/jwst/associations/asn_from_list.html

https://jwst-pipeline.readthedocs.io/en/latest/jwst/associations/level3_asn_technical.html
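
For orientation, a Level 3 association is a JSON file that lists the member exposures under a product name. The sketch below builds a minimal one with only the standard library. The file names are hypothetical placeholders and the field values are assumptions, so in practice `asn_from_list` should be used to generate a file the pipeline accepts.

```python
import json

# Hypothetical calibrated exposures (placeholder names, not real files)
cal_files = [
    "jw01433010001_01101_00001_nrca1_cal.fits",
    "jw01433010001_01101_00002_nrca1_cal.fits",
]

# Minimal Level 3 association layout: one product with a list of members
asn = {
    "asn_type": "image3",
    "asn_rule": "DMS_Level3_Base",
    "asn_pool": "none",
    "products": [
        {
            "name": "MACS0647_F115W",
            "members": [
                {"expname": name, "exptype": "science"} for name in cal_files
            ],
        }
    ],
}

asn_text = json.dumps(asn, indent=4)
print(asn_text)
```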

https://github.com/spacetelescope/nircam_calib/blob/master/nircam_calib/training_notebooks/jwst_pipeline_walkthrough.ipynb

# Outstanding issues / room for improvement:

- Darks
- Stars too small
- Align SW & LW images
- CEERS improvements
- Color images

# Import general

In [None]:
import os
from os.path import expanduser
home = expanduser("~")
from glob import glob
import numpy as np
from astropy.io import fits

# Inputs

In [None]:
filt_to_process = 'F115W'
# Uncomment CEERS tweaks below next time you run this
obs_to_process = ['010', '020']  # 2 NIRCam epochs
module_to_process = 'A'
exposures_to_process = 'all'  # [1]  # e.g., 1, 2, 3, 4 -OR- 'all' to process all

In [None]:
# Specify the xml and pointing files exported from APT
APT_input_dir     = './inputs/'  # APT
APT_file = os.path.join(APT_input_dir, 'macs0647_NIRCam')
xml_file      = APT_file + '.xml'
pointing_file = APT_file + '.pointing'

In [None]:
# Source catalogs
target = 'MACS0647+7015'  # must correspond to observed target in APT file!!
# Otherwise may throw error when calculating catalog seed image
cat_dict = {target:{}}
galaxy_catalog_file = 'MACS0647_MIRAGE_galaxy_catalog_%s.cat' % filt_to_process
cat_dict[target]['galaxy'] = galaxy_catalog_file
#cat_dict[target]['point_source'] = 'imaging_example_data/ptsrc_catalog.cat'

In [None]:
reffile_defaults = 'crds'  # Reference file values: crds or crds_full_name
cosmic_rays = {'library': 'SUNMAX', 'scale': 1.0}  # Cosmic ray library and rate
background = 'medium'
pav3 = 12.5  # telescope roll angle
dates = '2022-10-31'  # won't be used by MIRAGE, but will be added to FITS headers

# Outputs

In [None]:
#output_dir     = './yaml/'    # yaml files
simulation_dir = './images/'  # simulated images
output_dir = simulation_dir

#yaml_dir  = './yaml/'    # yaml files
image_dir = './images/'  # simulated images
yaml_dir = image_dir

#datatype = 'linear, raw'  # Save both raw (for JWST pipeline) and linear (processed except for dark current subtraction)
datatype = 'raw'  # Save raw images only for JWST pipeline

# Import MIRAGE

In [None]:
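# Set environment variables before importing mirage so that CRDS picks them up
os.environ["CRDS_DATA"] = os.path.join(home, "crds_cache")  # "$HOME" is not expanded inside a Python string
os.environ["CRDS_PATH"] = os.environ["CRDS_DATA"]  # variable name read by the CRDS client
os.environ["CRDS_SERVER_URL"] = "https://jwst-crds.stsci.edu"
os.environ["MIRAGE_DATA"] = "/ifs/jwst/wit/mirage_data"  # internal to STScI
import mirage
mirage.__version__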

In [None]:
if 0:  # for users external to STScI
    # Download 343 GB of files (will take some time!)
    from mirage.reference_files import downloader
    download_path = os.path.join(home, 'MIRAGE', 'data')
    os.makedirs(download_path, exist_ok=True)
    downloader.download_reffiles(download_path, instrument='nircam', dark_type='both',  # linearized
                                 skip_darks=False, skip_cosmic_rays=False, skip_psfs=False, skip_grism=True)
    
    os.environ["MIRAGE_DATA"] = download_path

In [None]:
# mirage imports
from mirage import imaging_simulator
from mirage.seed_image import catalog_seed_image
from mirage.dark import dark_prep
from mirage.ramp_generator import obs_generator
from mirage.yaml import yaml_generator

In [None]:
import yaml

# Import JWST Pipeline

In [None]:
import jwst
from jwst.pipeline import Detector1Pipeline, Image2Pipeline
from jwst.associations.lib.rules_level2_base import DMSLevel2bBase
from jwst.associations.lib.rules_level3_base import DMS_Level3_Base
from jwst.associations import asn_from_list
from jwst.pipeline import calwebb_image3
jwst.__version__

# Run

In [None]:
# Run the yaml generator (takes a minute or so)
yam = yaml_generator.SimInput(input_xml=xml_file, pointing_file=pointing_file, catalogs=cat_dict, 
                              cosmic_rays=cosmic_rays, background=background, roll_angle=pav3, dates=dates, 
                              reffile_defaults=reffile_defaults, datatype=datatype, verbose=True,
                              output_dir=yaml_dir, simdata_output_dir=image_dir)

# https://mirage-data-simulator.readthedocs.io/en/latest/dark_preparation.html#calibration-and-linearization
yam.use_linearized_darks = False

yam.create_inputs()

In [None]:
yam.use_linearized_darks

In [None]:
# 4 filters x 4 dithers x 10 detectors (as specified in APT)
yfiles = glob(os.path.join(yaml_dir,'jw*.yaml'))
yfiles = np.sort(yfiles)
len(yfiles)

In [None]:
# Print info about these files: filter and detector
for yamlfile in yfiles:
    with open(yamlfile, 'r') as infile:
        params = yaml.safe_load(infile)
    filt = params['Readout']['filter']
    detector = params['Readout']['array_name'][3:5]
    print(filt, detector, params['Output']['observation_number'], yamlfile)

In [None]:
#params

In [None]:
# Inspect the full parameter set of the first yaml file for detector A2
for yamlfile in yfiles:
    with open(yamlfile, 'r') as infile:
        params = yaml.safe_load(infile)
    filt = params['Readout']['filter']
    detector = params['Readout']['array_name'][3:5]
    if detector == 'A2':
        break
        
params

In [None]:
int(params['Output']['exposure_number'])

In [None]:
if   module_to_process == 'A':
    yfiles = [yfile for yfile in yfiles if '_nrca' in yfile]  # Select only images in NIRCam Module A
elif module_to_process == 'B':
    yfiles = [yfile for yfile in yfiles if '_nrcb' in yfile]  # Select only images in NIRCam Module B

#a5files = [yfile for yfile in yfiles if 'a5.' in yfile] # Select only Module A Long Wavelength images: detector A5

In [None]:
#exposures_to_process = [1]

In [None]:
# Select filter observations
yamls_to_process = []
for yamlfile in yfiles:
    with open(yamlfile, 'r') as infile:
        params = yaml.safe_load(infile)
    yaml_filt = params['Readout']['filter']
    obs_num = params['Output']['observation_number']
    if obs_num in obs_to_process:
        if yaml_filt == filt_to_process:
            if (exposures_to_process == 'all') or (int(params['Output']['exposure_number']) in exposures_to_process):
                yamls_to_process.append(yamlfile)
        
len(yamls_to_process)

In [None]:
yamls_to_process

In [None]:
# Only create images that haven't been created already
yamls = np.sort(yamls_to_process)
yamls_to_process = []
for yamlfile in yamls:
    outfits = yamlfile.replace('.yaml', '_uncal.fits')
    already_did_it = os.path.exists(outfits)
    havent_done_it_yet = not already_did_it
    if havent_done_it_yet:
        yamls_to_process.append(yamlfile)

In [None]:
len(yamls_to_process) 

In [None]:
yamls_to_process = np.sort(yamls_to_process)
yamls_to_process

# Create the simulated images (will take a while)
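The yaml files and output images are named following the JWST exposure naming convention, which is what the string matching on `_nrca` and the `_uncal.fits` replacements in this notebook rely on. A minimal sketch of a parser for such names follows. The regular expression and the function name `parse_jwst_filename` are illustrative assumptions, not part of MIRAGE or the JWST pipeline.

```python
import re

# jw<program><observation><visit>_<visit group><parallel seq><activity>_<exposure>_<detector>[_<suffix>]
FILENAME_PATTERN = re.compile(
    r"jw(?P<program>\d{5})(?P<observation>\d{3})(?P<visit>\d{3})"
    r"_(?P<visit_group>\d{2})(?P<parallel_seq>\d)(?P<activity>[0-9a-z]{2})"
    r"_(?P<exposure>\d{5})"
    r"_(?P<detector>[a-z0-9]+)"
    r"(?:_(?P<suffix>[a-z0-9]+))?\.(?:fits|yaml)$"
)

def parse_jwst_filename(filename):
    """Split a JWST-style exposure file name into its labeled fields."""
    match = FILENAME_PATTERN.search(filename)
    if match is None:
        raise ValueError(f"Not a recognized JWST exposure file name: {filename}")
    return match.groupdict()

info = parse_jwst_filename("jw01433010001_01101_00012_nrca5_uncal.fits")
print(info["observation"], info["detector"], info["suffix"])
```

The `observation` field corresponds to the values in `obs_to_process`, and the fourth character of `detector` gives the module (`a` or `b`).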

In [None]:
# 30 minutes per image
cal_images = []
for yfile in yamls_to_process:
    print(yfile)
    uncal_image = yfile.replace('.yaml', '_uncal.fits').replace(yaml_dir, image_dir)
    rate_image = uncal_image.replace('_uncal.fits', '_rate.fits')
    cal_image  = uncal_image.replace('_uncal.fits', '_cal.fits')
    cal_images.append(cal_image)

    # MIRAGE
    # 10 minutes per image
    if not os.path.exists(uncal_image):
        m = imaging_simulator.ImgSim()
        m.paramfile = yfile
        m.create()
        
    # Make sure the NOUTPUTS keyword (number of detector outputs) is present
    with fits.open(uncal_image, mode='update') as hdu:
        if 'NOUTPUTS' not in hdu[0].header:
            hdu[0].header['NOUTPUTS'] = 4  # written back to disk when the file closes

    # JWST Pipeline
    # 20 minutes per image
    if not os.path.exists(cal_image):
        result1 = Detector1Pipeline()
        result1.dark_current.skip = False
        result1.jump.rejection_threshold = 21
        result1.ipc.skip = False  # Correct for interpixel capacitance simulated by MIRAGE
        result1.persistence.skip = True  # Persistence not simulated by MIRAGE
        result1.save_results = True
        result1.output_dir = image_dir
        result1.run(uncal_image) # uncal -> rate
        #
        result2 = Image2Pipeline()
        result2.resample.skip = True  # Don't produce individual i2d images (rectified quick-look)
        result2.save_results = True
        result2.output_dir = image_dir
        result2.run(rate_image) # rate -> cal (no i2d, since resample is skipped)

In [None]:
# jw01433010001_01101_00012_nrca5_uncal.fits
#glob('./images/*uncal.fits')

# Combine exposures into one image
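The resampling step in this section chooses its output pixel scale by channel. A minimal sketch of that lookup as reusable helpers, using the 2.35 micron split and the pixel scales quoted in this notebook. The function names and the example filter list are assumptions for illustration.

```python
# Native and output pixel scales (arcsec / pixel) quoted in this notebook
PIXEL_SCALES = {
    "sw": {"native": 0.031, "output": 0.015},
    "lw": {"native": 0.063, "output": 0.03},
}

def nircam_channel(filter_name, split=235):
    """Return 'sw' or 'lw' from a filter name such as 'F115W' or 'F444W'.

    The three digits give the wavelength in units of 0.01 micron; the split
    value of 235 (2.35 micron) is the threshold used in this notebook.
    """
    wavelength = int(filter_name[1:4])
    return "lw" if wavelength > split else "sw"

def pixel_scale_ratio(filter_name):
    """Ratio of output to native pixel scale for the filter's channel."""
    scales = PIXEL_SCALES[nircam_channel(filter_name)]
    return scales["output"] / scales["native"]

# Example filter names, chosen only for illustration
for name in ["F115W", "F200W", "F277W", "F444W"]:
    print(name, nircam_channel(name), round(pixel_scale_ratio(name), 3))
```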

In [None]:
cal_images = []
for yfile in yamls_to_process:
    uncal_image = yfile.replace('.yaml', '_uncal.fits').replace(output_dir, simulation_dir)
    rate_image = uncal_image.replace('_uncal.fits', '_rate.fits')
    cal_image  = uncal_image.replace('_uncal.fits', '_cal.fits')
    cal_images.append(cal_image)
    print(cal_image)

In [None]:
len(cal_images)

In [None]:
cal_images = glob('images/*_cal.fits')
len(cal_images)

In [None]:
for x in cal_images:
    print(x)

In [None]:
association_file = 'MACS0647_%s_image_associations.json' % filt_to_process
association_file

In [None]:
association = asn_from_list.asn_from_list(cal_images, rule=DMS_Level3_Base, 
                                          product_name='MACS0647_'+filt_to_process,
                                          asn_rule='Asn_Image')
                                          #asn_type='image3')

with open(association_file, 'w') as fh:
   fh.write(association.dump()[1])

In [None]:
lam = int(filt_to_process[1:4])
channel = ['sw', 'lw'][lam > 235]    
filt_to_process, lam, channel

In [None]:
# Including recommendations from CEERS program DR1 ceers_nircam_reduction.ipynb
# 10 minutes for 16 SW images -> 1 image
# ## minutes for  4 LW images -> 1 image
m = calwebb_image3.Image3Pipeline()
m.tweakreg.skip = True  # Turn off TweakRegStep since these simulated images are perfectly aligned (no guide star uncertainties)
#m.skymatch.skip = True  # Turn off SkyMatchStep
m.outlier_detection.skip = False
m.source_catalog.snr_threshold = 5  # 20
m.source_catalog.output_file = "MACS0647_%s_cat.ecsv" % filt_to_process
m.save_results = True  # _i2d.fits

if channel == 'sw':
    m.resample.pixel_scale_ratio = 0.015 / 0.031  # SW images 0.031" -> 0.015" / pix
elif channel == 'lw':
    m.resample.pixel_scale_ratio = 0.03 / 0.063  # LW images 0.063" -> 0.03" / pix
else:
    print('Unknown channel and pixel scale')

m.run(association_file)  # run the pipeline with these parameters on these images in association file
#m.save_model()  # generate i2d.fits in case you didn't above

In [None]:
output_image = 'MACS0647_%s_i2d.fits' % filt_to_process

# Show results

In [None]:
%matplotlib notebook
#%matplotlib inline
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm
#from scipy.stats import sigmaclip

def show(data, percent=99.6, label='MJy/sr'):
    # Default label matches the calibrated (cal, i2d) products; pass label= for other data
    plt.figure(figsize=(12,12))
    norm = simple_norm(data, 'asinh', percent=percent)
    plt.imshow(data,norm=norm)
    plt.colorbar().set_label(label)

In [None]:
data = fits.open(output_image)['SCI'].data
print(data.shape)
show(data)
output_image

In [None]:
if 0:
    raw_data = fits.open(uncal_image)['SCI'].data
    print(raw_data.shape)
    data = 1. * raw_data[0, 3, :, :] - 1. * raw_data[0, 0, :, :]
    show(data)
    print(uncal_image)

In [None]:
# If linear images were saved
if 0:
    linear_image_file = yfile.replace('.yaml', '_linear.fits').replace(output_dir, simulation_dir)
    linear_data = fits.open(linear_image_file)['SCI'].data
    show(linear_data[0, 3, :, :])
    linear_image_file

The raw data file is now ready to be run through the [JWST calibration pipeline](https://jwst-pipeline.readthedocs.io/en/stable/) from the beginning. If dark current subtraction is not important for you, you can use Mirage's linear output, skip some of the initial steps of the pipeline, and begin by running the [Jump detection](https://jwst-pipeline.readthedocs.io/en/stable/jwst/jump/index.html?highlight=jump) and [ramp fitting](https://jwst-pipeline.readthedocs.io/en/stable/jwst/ramp_fitting/index.html) steps.

---
<a id='mult_sims'></a>
## Simulating Multiple Exposures

Each yaml file will simulate an exposure for a single pointing using a single detector. To simulate multiple exposures, or a single exposure with multiple detectors, multiple calls to the *imaging_simulator* must be made.

### In Series
```python
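# Placeholders: one yaml file per detector, e.g. entries of yamls_to_process
paramlist = [yaml_a1,yaml_a2,yaml_a3,yaml_a4,yaml_a5]

def many_sim(paramlist):
    '''Function to run many simulations in series
    '''
    for paramfile in paramlist:
        m = imaging_simulator.ImgSim()
        m.paramfile = paramfile
        m.create()

many_sim(paramlist)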
```

### In Parallel

Since each `yaml` simulation does not depend on the others, we can parallelize the process to speed things up:
```python
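from multiprocessing import Pool

n_procs = 5 # number of cores available

def make_sim(paramfile):
    '''Run a single simulation from one yaml file'''
    m = imaging_simulator.ImgSim()
    m.paramfile = paramfile
    m.create()

with Pool(n_procs) as pool:
    pool.map(make_sim, paramlist)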
```
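
The number of worker processes does not have to be hard-coded. A small sketch, assuming it is acceptable to leave one core free for the rest of the system (that choice is an assumption, not a MIRAGE recommendation):

```python
import os

# os.cpu_count() may return None when the count cannot be determined
n_cores = os.cpu_count() or 1

# Leave one core free, but always use at least one process
n_procs = max(1, n_cores - 1)
print(n_procs)
```

Each simulation may also hold large arrays in memory, so available memory can limit the practical value of `n_procs` before the core count does.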

In [None]:
# https://techwiser.com/how-many-cores-does-my-cpu-have/
n_procs = 6 # number of cores available