In [1]:
import numpy as np
from asdf import AsdfFile
from astropy.io import fits
from astropy import wcs as astwcs

from gwcs import wcs
from jwst import datamodels
from jwst.assign_wcs import nirspec
from jwst.transforms import models

In [2]:
# Create a reference-file dict by asking the AssignWcs step which file to use
# for each reference type.
def create_reference_files(datamodel):
    """
    Create a dict {reftype: reference_file}.

    Parameters
    ----------
    datamodel : jwst datamodel
        Model passed to ``AssignWcsStep.get_reference_file`` to select the
        reference files.

    Returns
    -------
    dict
        Maps every entry of ``AssignWcsStep.reference_file_types`` to the value
        returned by ``get_reference_file`` for that type.
    """
    # Imported locally: the notebook's import cell does not bring
    # ``assign_wcs_step`` into scope, so referencing it at module level raises
    # NameError.
    from jwst.assign_wcs import assign_wcs_step

    step = assign_wcs_step.AssignWcsStep()
    return {reftype: step.get_reference_file(datamodel, reftype)
            for reftype in assign_wcs_step.AssignWcsStep.reference_file_types}

# These are the CV3 files.
# Most reference files come from one CRDS cache directory; change
# CRDS_JWST_REFS to point the notebook at a different cache. 'ifupost' and
# 'ote' are relative paths (resolved against the current working directory),
# and 'N/A' marks reference types for which no file is supplied here.
CRDS_JWST_REFS = '/grp/crds/cache/references/jwst/'

refs = {'camera': CRDS_JWST_REFS + 'jwst_nirspec_camera_0004.asdf',
        'collimator': CRDS_JWST_REFS + 'jwst_nirspec_collimator_0004.asdf',
        'disperser': CRDS_JWST_REFS + 'jwst_nirspec_disperser_0035.asdf',
        'distortion': 'N/A',
        'filteroffset': 'N/A',
        'fore': CRDS_JWST_REFS + 'jwst_nirspec_fore_0022.asdf',
        'fpa': CRDS_JWST_REFS + 'jwst_nirspec_fpa_0005.asdf',
        'ifufore': CRDS_JWST_REFS + 'jwst_nirspec_ifufore_0003.asdf',
        'ifupost': 'ifupost.asdf',
        'ifuslicer': CRDS_JWST_REFS + 'jwst_nirspec_ifuslicer_0003.asdf',
        'msa': CRDS_JWST_REFS + 'jwst_nirspec_msa_0005.asdf',
        'ote': 'ote.asdf',
        'regions': 'N/A',
        'specwcs': 'N/A',
        'wavelengthrange': CRDS_JWST_REFS + 'jwst_nirspec_wavelengthrange_0004.asdf'}


In [3]:
# WCS keywords copied (in this order) into the SCI header of the synthetic
# exposure built by create_hdul.
wcs_kw = {
    'wcsaxes': 2,
    # reference pointing, V2/V3 reference point and roll
    'ra_ref': 165,
    'dec_ref': 54,
    'v2_ref': -8.3942412,
    'v3_ref': -5.3123744,
    'roll_ref': 37,
    # FITS tangent-plane projection: reference pixel, scale, projection type
    'crpix1': 1024,
    'crpix2': 1024,
    'cdelt1': 0.08,
    'cdelt2': 0.08,
    'ctype1': 'RA---TAN',
    'ctype2': 'DEC--TAN',
    # identity PC matrix
    'pc1_1': 1,
    'pc1_2': 0,
    'pc2_1': 0,
    'pc2_2': 1,
}

# Slit attribute names, split into numeric and string-valued fields.
# NOTE(review): neither list is referenced elsewhere in this notebook.
slit_fields_num = [
    "shutter_id", "xcen", "ycen", "ymin", "ymax",
    "quadrant", "source_id", "stellarity", "source_xpos", "source_ypos",
]
slit_fields_str = ["name", "shutter_state", "source_name", "source_alias"]


In [4]:
def create_hdul(detector='NRS1', wcs_keywords=None):
    """
    Create a fits HDUList instance.

    The list holds a primary HDU with NIRSpec instrument and observation
    keywords, followed by an empty ``SCI`` image extension carrying WCS
    keywords.

    Parameters
    ----------
    detector : str
        Value written to the primary-header ``DETECTOR`` keyword.
    wcs_keywords : dict, optional
        Keyword/value pairs copied, in iteration order, into the ``SCI``
        header. Defaults to the notebook-level ``wcs_kw`` dict.

    Returns
    -------
    astropy.io.fits.HDUList
        ``[PrimaryHDU, ImageHDU]`` with the second HDU named ``SCI``.
    """
    # Looked up at call time so existing calls keep using the global wcs_kw.
    if wcs_keywords is None:
        wcs_keywords = wcs_kw

    hdul = fits.HDUList()
    phdu = fits.PrimaryHDU()
    phdu.header['instrume'] = 'NIRSPEC'
    phdu.header['detector'] = detector
    phdu.header['time-obs'] = '8:59:37'
    phdu.header['date-obs'] = '2016-09-05'

    scihdu = fits.ImageHDU()
    scihdu.header['EXTNAME'] = "SCI"
    for key, value in wcs_keywords.items():
        scihdu.header[key] = value

    hdul.append(phdu)
    hdul.append(scihdu)
    return hdul

In [5]:
def create_nirspec_fs_file(grating, filter, lamp="N/A", detector='NRS1'):
    """
    Build an in-memory NIRSpec fixed-slit (``NRS_FIXEDSLIT``) exposure.

    Starts from ``create_hdul(detector)``, adds the observing-mode keywords to
    the primary header and turns the SCI WCS into a three-axis WCS whose third
    axis is ``WAVE``.

    Parameters
    ----------
    grating, filter, lamp : str
        Written to the GRATING, FILTER and LAMP primary-header keywords.
    detector : str
        Passed through to ``create_hdul``.

    Returns
    -------
    astropy.io.fits.HDUList
    """
    image = create_hdul(detector)
    primary_header = image[0].header
    sci_header = image[1].header

    # Primary-header keywords, written in this order. GWA_XTIL/GWA_YTIL are the
    # grating-wheel tilt values the pipeline logs as gwa_xtilt/gwa_ytilt.
    primary_settings = (
        ('exp_type', 'NRS_FIXEDSLIT'),
        ('filter', filter),
        ('grating', grating),
        ('lamp', lamp),
        ('GWA_XTIL', 0.3316612243652344),
        ('GWA_YTIL', 0.1260581910610199),
        ('SUBARRAY', "FULL"),
    )
    for keyword, value in primary_settings:
        primary_header[keyword] = value

    # Add a third, spectral WCS axis to the science extension.
    sci_header['crval3'] = 0
    sci_header['wcsaxes'] = 3
    sci_header['ctype3'] = 'WAVE'
    return image



In [6]:
# Synthetic fixed-slit exposure: G395H grating with the F290LP filter on NRS1.
hdul = create_nirspec_fs_file(detector='NRS1', grating='G395H', filter='F290LP')
im = datamodels.ImageModel(hdul)

# Wrap the pipeline built from the reference files in a gWCS object and attach
# it to the model; later cells extract per-slit transforms from it.
pipeline = nirspec.create_pipeline(im, refs)
im.meta.wcs = w = wcs.WCS(pipeline)


2018-04-28 10:42:53,510 - stpipe - INFO - gwa_ytilt is 0.1260581910610199 deg
2018-04-28 10:42:53,511 - stpipe - INFO - gwa_xtilt is 0.3316612243652344 deg
2018-04-28 10:42:53,512 - stpipe - INFO - theta_y correction: -0.009545474118238594 deg
2018-04-28 10:42:53,514 - stpipe - INFO - theta_x correction: 0.0 deg
2018-04-28 10:42:56,992 - stpipe - INFO - Removing slit S200B1 from the list of open slits because the WCS bounding_box is completely outside the detector.
2018-04-28 10:42:56,994 - stpipe - INFO - Slits projected on detector NRS1: ['S200A1', 'S200A2', 'S400A1', 'S1600A1']
2018-04-28 10:42:56,995 - stpipe - INFO - Computing WCS for 4 open slitlets
2018-04-28 10:42:57,119 - stpipe - INFO - gwa_ytilt is 0.1260581910610199 deg
2018-04-28 10:42:57,120 - stpipe - INFO - gwa_xtilt is 0.3316612243652344 deg
2018-04-28 10:42:57,121 - stpipe - INFO - theta_y correction: -0.009545474118238594 deg
2018-04-28 10:42:57,124 - stpipe - INFO - theta_x correction: 0.0 deg
2018-04-28 10:42:57,21

In [7]:
# Set up the test inputs: five points along the slit's y axis (x fixed at 0),
# each paired with one wavelength.
slity = [-.5, -.25, 0, .25, .5]
slitx = [0] * len(slity)
wave_range = [2.87e-06, 5.27e-06]  # NOTE(review): not used later in this notebook
# presumably microns converted to metres -- TODO confirm units
lam = np.array([2.9, 3.39, 3.88, 4.37, 5]) * 10**-6

# All following transforms are taken from slit S200A1's WCS.
slit_wcs = nirspec.nrs_wcs_set_input(im, 'S200A1')

In [8]:
# Slit frame -> absolute MSA coordinates (third output is discarded).
slit2msa = slit_wcs.get_transform('slit_frame', 'msa_frame')
msax, msay, _ = slit2msa(slitx, slity, lam)
for label, values in (('slitx: ', slitx), ('slity: ', slity),
                      ('msax: ', msax), ('msay: ', msay)):
    print(label, values)


slitx:  [0, 0, 0, 0, 0]
slity:  [-0.5, -0.25, 0, 0.25, 0.5]
msax:  [0.02697243 0.02697243 0.02697243 0.02697243 0.02697243]
msay:  [-0.00335229 -0.00303449 -0.0027167  -0.00239891 -0.00208112]


In [9]:
# Coordinates at the collimator exit.
# Apply the collimator forward transform to the absolute MSA coordinates.
# The `with` block closes the reference model even if the transform raises.
with datamodels.open(refs['collimator']) as col:
    colx, coly = col.model(msax, msay)
print('x_collimator_exit', colx)
print('y_collimator_exit', coly)

x_collimator_exit [0.01481493 0.01480212 0.01478927 0.01477638 0.01476345]
y_collimator_exit [0.17838953 0.17855393 0.17871813 0.17888213 0.17904594]


In [10]:
# MSA to GWA entrance.
# This runs the collimator forward transform, unitless -> directional cosines,
# and the 3D rotation, using the tilt-corrected disperser.
# The `with` block closes the disperser model even if a step raises.
with datamodels.DisperserModel(refs['disperser']) as disp:
    disperser = nirspec.correct_tilt(disp, im.meta.instrument.gwa_xtilt,
                                     im.meta.instrument.gwa_ytilt)
    collimator2gwa = nirspec.collimator_to_gwa(refs, disperser)
    x_gwa_in, y_gwa_in, z_gwa_in = collimator2gwa(msax, msay)
# NOTE(review): `disperser` is used again in the camera-entrance cell after
# `disp` is closed; this relies on correct_tilt returning an independent
# object -- confirm.
print('x_gwa_entrance:', x_gwa_in)
print('y_gwa_entrance:', y_gwa_in)
print('z_gwa_entrance:', z_gwa_in)

2018-04-28 10:43:03,012 - stpipe - INFO - gwa_ytilt is 0.1260581910610199 deg
2018-04-28 10:43:03,013 - stpipe - INFO - gwa_xtilt is 0.3316612243652344 deg
2018-04-28 10:43:03,014 - stpipe - INFO - theta_y correction: -0.009545474118238594 deg
2018-04-28 10:43:03,017 - stpipe - INFO - theta_x correction: 0.0 deg


x_gwa_entrance: [0.18738332 0.18740684 0.18743033 0.18745378 0.1874772 ]
y_gwa_entrance: [-0.28680741 -0.28634706 -0.28588663 -0.2854261  -0.28496549]
z_gwa_entrance: [0.93948337 0.93961909 0.9397546  0.9398899  0.94002498]


In [11]:
# Slit to GWA out
slit2gwa = slit_wcs.get_transform('slit_frame', 'gwa')
x_gwa_out, y_gwa_out, z_gwa_out = slit2gwa(slitx, slity, lam)
print('x_gwa_exit:' , x_gwa_out)
print('y_gwa_exit:' , y_gwa_out)
print('z_gwa_exit:' , z_gwa_out)

x_gwa_exit: [0.07190191 0.11568865 0.15947543 0.20326224 0.2595663 ]
y_gwa_exit: [0.28680741 0.28634706 0.28588663 0.2854261  0.28496549]
z_gwa_exit: [0.95528615 0.95111592 0.94490022 0.93659831 0.92272423]


In [12]:
# CAMERA entrance (assuming direction is from sky to detector).
# Undo the GWA rotation using the tilt-corrected disperser angles, then convert
# the direction cosines to unitless coordinates.
angle_keys = ('theta_x', 'theta_y', 'theta_z', 'tilt_y')
angles = [disperser[key] for key in angle_keys]
rotation = models.Rotation3DToGWA(angles, axes_order="xyzy", name='rotation')
dircos2unitless = models.DirCos2Unitless()

gwa2cam = rotation.inverse | dircos2unitless
x_camera_entrance, y_camera_entrance = gwa2cam(x_gwa_out, y_gwa_out, z_gwa_out)
for label, values in (('x_camera_entrance:', x_camera_entrance),
                      ('y_camera_entrance:', y_camera_entrance)):
    print(label, values)

x_camera_entrance: [-0.07707064 -0.03101948  0.01514077  0.06171155  0.12272885]
y_camera_entrance: [0.30046859 0.29882465 0.29781475 0.29744381 0.29808709]


In [13]:
# At the FPA: apply the inverse camera transform to the camera-entrance
# coordinates. The `with` block closes the model even if the transform raises.
with datamodels.CameraModel(refs['camera']) as camera:
    x_fpa, y_fpa = camera.model.inverse(x_camera_entrance, y_camera_entrance)
print('x_fpa: ', x_fpa)
print('y_fpa: ', y_fpa)

x_fpa:  [-0.02239022 -0.00937323  0.00372075  0.01691372  0.03406922]
y_fpa:  [0.00142652 0.00119317 0.00103662 0.0009563  0.00101618]


In [14]:
# At the SCA.
# NRS1: the exposure was built for NRS1, so the slit WCS reaches 'sca' directly.
slit2sca = slit_wcs.get_transform('slit_frame', 'sca')
x_sca_nrs1, y_sca_nrs1 = slit2sca(slitx, slity, lam)

# NRS2: map the FPA coordinates through the inverse of the FPA model's NRS2
# transform. The `with` block closes the model even if the transform raises.
with datamodels.FPAModel(refs['fpa']) as fpa:
    x_sca_nrs2, y_sca_nrs2 = fpa.nrs2_model.inverse(x_fpa, y_fpa)

print('x_sca1: ', x_sca_nrs1)
print('y_sca1: ', y_sca_nrs1)
print('x_sca2: ', x_sca_nrs2)
print('y_sca2: ', y_sca_nrs2)

x_sca1:  [ 876.70094739 1599.86683035 2327.31023955 3060.25306449 4013.33665474]
y_sca1:  [1102.75121505 1089.78694793 1081.08997251 1076.6275921  1079.95467952]
x_sca2:  [3364.51728613 2641.35159688 1913.90831766 1180.96555945  227.88191961]
y_sca2:  [944.16217632 957.13724461 965.84508511 970.31841273 967.00556052]


In [15]:
# Slit frame -> 'oteip' frame (third output is discarded).
slit2oteip = slit_wcs.get_transform('slit_frame', 'oteip')
x_oteip, y_oteip, _ = slit2oteip(slitx, slity, lam)
for label, values in (('x_oteip: ', x_oteip), ('y_oteip: ', y_oteip)):
    print(label, values)

x_oteip:  [0.02964448 0.02999172 0.03033894 0.03068614 0.03103378]
y_oteip:  [-0.03278024 -0.03238316 -0.03198612 -0.03158912 -0.0311919 ]


In [16]:
# At V2/V3. The 'v2v3' frame output is in arcsec; dividing by 3600 converts it
# to degrees, so the printed (and later saved) v2/v3 values are in degrees.
slit2v23 = slit_wcs.get_transform('slit_frame', 'v2v3')
v2_arcsec, v3_arcsec, _ = slit2v23(slitx, slity, lam)
v2 = v2_arcsec / 3600
v3 = v3_arcsec / 3600
print('v2: ', v2)
print('v3: ', v3)

v2:  [0.09254187 0.09238998 0.09223809 0.09208622 0.09193416]
v3:  [-0.13344017 -0.13326734 -0.13309453 -0.13292174 -0.13274886]


In [17]:
# Save every intermediate result to an ASDF file. Keys (and their order) match
# the variable names used above; each value is stored as a Python list.
# v2/v3 hold the degree values computed in the V2/V3 cell.
results = {
    'slitx': slitx,
    'slity': slity,
    'lam': lam,
    'msax': msax,
    'msay': msay,
    'x_collimator_exit': colx,
    'y_collimator_exit': coly,
    'x_gwa_entrance': x_gwa_in,
    'y_gwa_entrance': y_gwa_in,
    'z_gwa_entrance': z_gwa_in,
    'x_gwa_exit': x_gwa_out,
    'y_gwa_exit': y_gwa_out,
    'z_gwa_exit': z_gwa_out,
    'x_camera_entrance': x_camera_entrance,
    'y_camera_entrance': y_camera_entrance,
    'x_fpa': x_fpa,
    'y_fpa': y_fpa,
    'x_sca_nrs1': x_sca_nrs1,
    'y_sca_nrs1': y_sca_nrs1,
    'x_sca_nrs2': x_sca_nrs2,
    'y_sca_nrs2': y_sca_nrs2,
    'x_oteip': x_oteip,
    'y_oteip': y_oteip,
    'v2': v2,
    'v3': v3,
}

fa = AsdfFile()
for key, values in results.items():
    fa.tree[key] = list(values)
fa.write_to("fixed_slits_functional.asdf", all_array_storage="internal")