In [1]:
# generic imports
import numpy as np
import astropy.io.fits as pyfits

In [2]:
# had to muck with my pythonpath to get my own pyast fork in place.
import os
print os.environ['PYTHONPATH']

/Users/parejkoj/lsst/lsstsw/lsst_build/python:/Users/parejkoj/lsst/temp/starlink-pyast:/Users/parejkoj/dev/python:/Users/parejkoj/lsst/lsstsw/eups/1.5.9/python


In [3]:
# LSST imports

In [4]:
# AST imports
from starlink import Ast
from starlink import Atl

In [5]:
# astropy imports
import gwcs
from gwcs import coordinate_frames as cf
from astropy.modeling import models
from astropy import units as u
from astropy import coordinates as coord

In [28]:
# Input image. An alternative CFHT input is kept here for reference:
# infile = '/Users/parejkoj/lsst/afwdata/CFHT/D4/cal-53535-i-797722_1.fits'
infile = '/Users/parejkoj/lsst/simastrom/validation_data_cfht/raw/849375p.fits.fz'
data = pyfits.open(infile)

# numpy arrays are indexed (row, column) = (y, x), the reverse of the FITS
# NAXIS1/NAXIS2 order, so unpack the shape as (ny, nx): nx is then the number
# of columns (x extent) and ny the number of rows (y extent).
ny, nx = data[1].data.shape

In [7]:
# make some gWCS objects from the file

# get the basic WCS transformation from the FITS headers
# Have to fake a keyword: https://github.com/spacetelescope/gwcs/issues/40
data[1].header['WCSAXES'] = 2
fits_transform = gwcs.utils.make_fitswcs_transform(data[1].header)

# Create a simple distortion model from two 2D polynomails
x_distort = models.Polynomial2D(2, c1_0=1, c0_1=1)
y_distort = models.Polynomial2D(2, c1_0=1, c0_1=1)
distortion = models.Mapping((0,1,0,1)) | x_distort & y_distort
# A potentially useful test coordinate is the identity.
distortion = models.Identity(2)

# Create some coordinate frames to map between
detector = cf.Frame2D(name='detector', axes_order=(0,1), unit=(u.pix, u.pix))
focal = cf.Frame2D(name='focal', axes_order=(0,1), unit=(u.arcmin, u.arcmin))
sky = cf.CelestialFrame(name='icrs', reference_frame=coord.ICRS())

# tuples of frame:mapping
# The last frame has None for the transform.
pipeline = [(detector, distortion),
                (focal, fits_transform),
                (sky, None)
                ]
wcs = gwcs.wcs.WCS(pipeline)

# print the on-sky coordinates of some pixels, and the corners
print wcs
print wcs(0,0)
print wcs(100,100)
print wcs.footprint(data[1].data.shape)

  From   Transform
-------- ---------
detector      None
   focal      None
    icrs      None
(215.50450967846268, 53.151131818465764)
(215.51298666101513, 53.145896269730315)
[[ 215.50459446   53.15107947]
 [ 215.50363274   53.0427071 ]
 [ 215.89833684   53.03734126]
 [ 215.900288     53.14570507]]


In [15]:
# make some AST objects from the file.
data = pyfits.open(infile)
frameset,encoding = Atl.readfitswcs(data[1], Iwc=True)

# TBD: need to figure out how to construct a PolyMap.
# TBD: this is from the pyast test.py file:
# pm = Ast.PolyMap( [[1.2,1.,2.,0.],[-0.5,1.,1.,1.],[1.0,2.,0.,1.]])
# UnitMap is the Identity map: input a coordinate and just output it.
distortion = Ast.UnitMap(2)

# The original mapping. simplify() it to get rid of loops, etc.
map = frameset.getmapping(Ast.BASE, frameset.Nframe).simplify()
newmap = Ast.CmpMap(map, distortion)

# insert the new mapping
pixframe = frameset.getframe(Ast.BASE) #pointer to the base frame (pixel coords)
newmap.invert()
isky = frameset.Current
frameset.addframe(frameset.Nframe, newmap, pixframe)
newpix = frameset.Current
frameset.Current = isky

# Now delete the old frame. Why, I'm not sure.
oldpix = frameset.Base
frameset.Base = newpix
frameset.removeframe(oldpix)

# print the on-sky coordinates of some pixels, and the corners
xpixel = [0,100]
ypixel = [0,100]
ra,dec = frameset.tran([xpixel,ypixel])
print ra,dec
print (180/np.pi)*frameset.norm((ra[0],dec[0]))

#print frameset

[-2.52192206 -2.52177411] [ 0.92766225  0.92757087]
[ 215.50450968   53.15113182]


In [16]:
# 1-D pixel index vectors along each image axis...
xx, yy = np.arange(nx), np.arange(ny)
# ...expanded into 2-D coordinate grids: xv holds the x index and yv the y
# index of every pixel, both shaped (len(yy), len(xx)).
xv, yv = np.meshgrid(xx, yy)

In [25]:
# How do these two match up in terms of processing time?
# gwcs is called on the 2-D grids directly. AST's tran() is given a list of
# the flattened x and y arrays; flatten() always copies, so the AST timing
# also includes making those copies.
%timeit wcs(xv,yv)
%timeit frameset.tran([xv.flatten(),yv.flatten()])

1 loops, best of 3: 2.41 s per loop
1 loops, best of 3: 2.62 s per loop


In [87]:
# Evaluate both WCSs on the full pixel grid.
result_gwcs = wcs(xv, yv)

# AST's tran() takes and returns one row per axis, in radians, and does not
# normalize RA (e.g. about -144.5 deg where gwcs gives 215.5 deg). Convert to
# degrees, wrap RA into [0, 360) so it is directly comparable with gwcs, and
# reshape both axes back onto the pixel grid.
ra_ast, dec_ast = (180/np.pi)*frameset.tran([xv.flatten(), yv.flatten()])
result_ast = (np.mod(ra_ast, 360.0).reshape(xv.shape),
              dec_ast.reshape(xv.shape))

In [91]:
# The two libraries should agree to within floating-point rounding.
# Compare RA through its difference wrapped into [-180, 180) degrees, so that
# a 360-degree branch difference between the libraries is not a mismatch.
ra_diff = np.mod(result_ast[0] - result_gwcs[0] + 180.0, 360.0) - 180.0
dec_diff = result_ast[1] - result_gwcs[1]
assert np.allclose(ra_diff, 0.0), "AST and gwcs disagree in RA"
assert np.allclose(dec_diff, 0.0), "AST and gwcs disagree in Dec"
# Display the Dec residuals in degrees (expected at the ~1e-14 level).
dec_diff

array([[  7.10542736e-15,   7.10542736e-15,   7.10542736e-15, ...,
          7.10542736e-15,   7.10542736e-15,   1.42108547e-14],
       [  7.10542736e-15,   1.42108547e-14,   7.10542736e-15, ...,
          0.00000000e+00,   0.00000000e+00,   7.10542736e-15],
       [  0.00000000e+00,   0.00000000e+00,   7.10542736e-15, ...,
          7.10542736e-15,   0.00000000e+00,   0.00000000e+00],
       ..., 
       [  0.00000000e+00,   0.00000000e+00,   0.00000000e+00, ...,
          7.10542736e-15,   0.00000000e+00,   0.00000000e+00],
       [  7.10542736e-15,   7.10542736e-15,   7.10542736e-15, ...,
          0.00000000e+00,  -7.10542736e-15,   0.00000000e+00],
       [  0.00000000e+00,   7.10542736e-15,   7.10542736e-15, ...,
          0.00000000e+00,   0.00000000e+00,   0.00000000e+00]])