In [None]:
import galsim
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table, vstack
from astropy.time import Time
import astropy.units as u
import esutil
import tqdm
import ssa

In [None]:
# Distortion and vignetting functions for a 0.4 m CDK telescope similar to
# https://www.cloudynights.com/topic/587494-kevins-27-corrected-dall-kirkham/
# All three arrays are tabulated on the same grid of field angles `th`
# (they are interpolated against `th` when building the WCS and vignetting).

# Field angle from the optical axis, in degrees (0 .. 1 deg).
th = np.array([0.        , 0.03448276, 0.06896552, 0.10344828, 0.13793103,
       0.17241379, 0.20689655, 0.24137931, 0.27586207, 0.31034483,
       0.34482759, 0.37931034, 0.4137931 , 0.44827586, 0.48275862,
       0.51724138, 0.55172414, 0.5862069 , 0.62068966, 0.65517241,
       0.68965517, 0.72413793, 0.75862069, 0.79310345, 0.82758621,
       0.86206897, 0.89655172, 0.93103448, 0.96551724, 1.        ])
# Unvignetted fraction of the light reaching the focal plane at each field angle in `th`.
unvig = np.array([0.86972549, 0.87259059, 0.87087748, 0.86941706, 0.86827994,
       0.86648507, 0.86486115, 0.86346019, 0.86065827, 0.85650741,
       0.85182144, 0.82975165, 0.78771911, 0.73632577, 0.67956641,
       0.61945143, 0.55733349, 0.5019528 , 0.45624505, 0.41441689,
       0.37429813, 0.3345807 , 0.29365183, 0.24876162, 0.19516092,
       0.1444885 , 0.09833854, 0.05772925, 0.02454469, 0.00247119])
# Radial plate scale d(theta)/dr at each field angle in `th`; multiply by the
# pixel size in microns to get arcsec per pixel.
dthdr = np.array([0.06124069, 0.06126784, 0.06124352, 0.06120296, 0.06114596,
       0.06107225, 0.06098148, 0.06087321, 0.06074759, 0.06060343,
       0.0604463 , 0.06029759, 0.06011246, 0.05992273, 0.05973642,
       0.05947572, 0.05919784, 0.05893438, 0.05865396, 0.05834981,
       0.05815874, 0.05775472, 0.05736799, 0.05702492, 0.0565666 ,
       0.0563616 , 0.05585923, 0.05531151, 0.05477258, 0.05446725])  # arcsec / micron => ~ 1.5 arcsec per 25 micron pixel

In [None]:
def polyWCS(th, dthdr, world_origin, theta=0, n=10, order=3):
    """
    Make a WCS from a tabulated radial distortion by fitting a SIP WCS.

    A square ``n x n`` grid of field positions, clipped to the circle of
    radius ``th[-1]``, is placed on the sky with a postel (azimuthal
    equidistant) projection about ``world_origin``.  Each position's
    focal-plane radius (in pixels) is obtained by integrating the inverse
    plate scale, ``r(theta) = int_0^theta dtheta' / dthdr(theta')``, and the
    resulting pixel positions are optionally rotated.  A
    ``galsim.FittedSIPWCS`` is then fit to the (x, y) <-> (ra, dec) pairs.

    Parameters:
        th:  Field angles in degrees at which dthdr is tabulated; th[-1]
             sets the radius of the fitted field
        dthdr:  Radial plate scale in arcsec per pixel, tabulated at th
        world_origin:  galsim.CelestialCoord of the optical axis; it maps to
                       pixel (0, 0) before the SIP fit
        theta: Rotation angle in radians applied to the pixel coordinates
        n: number of control points per axis of the square grid
        order: order of SIP part of fitted WCS

    Returns:
        galsim.FittedSIPWCS
    """
    from scipy.interpolate import interp1d
    from scipy.integrate import quad, IntegrationWarning
    import warnings

    # Tangent-plane (postel) grid in radians, clipped to the circular field.
    # Named fu/fv to avoid shadowing the module-level astropy `u`.
    grid = np.deg2rad(np.linspace(-th[-1], th[-1], n))
    fu, fv = np.meshgrid(grid, grid)
    rho = np.hypot(fu, fv)
    inside = rho <= np.deg2rad(th[-1])
    fu = fu[inside]
    fv = fv[inside]
    rho = rho[inside]

    plate_scale = interp1d(th, dthdr, kind='cubic')  # deg -> arcsec/pix

    def pix_per_arcsec(arcsec):
        """Inverse plate scale (pix/arcsec) at a field angle given in arcsec."""
        return 1./plate_scale(arcsec/3600)

    # Focal-plane radius in pixels of each control point.
    r = np.empty_like(rho)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=IntegrationWarning)
        for i, rho_i in enumerate(rho):
            r[i] = quad(pix_per_arcsec, 0, np.rad2deg(rho_i)*3600)[0]

    # Scale each direction (fu, fv)/rho by r.  The optical axis (rho == 0,
    # present whenever n is odd) maps to the origin rather than 0/0 = nan,
    # which would otherwise poison the SIP fit.
    on_axis = rho > 0
    x = np.divide(r*fu, rho, out=np.zeros_like(r), where=on_axis)
    y = np.divide(r*fv, rho, out=np.zeros_like(r), where=on_axis)

    sth, cth = np.sin(theta), np.cos(theta)
    R = np.array([[cth, -sth], [sth, cth]])
    x, y = R @ np.array([x, y])

    ra, dec = world_origin._deproject(fu, fv, projection='postel')
    return galsim.FittedSIPWCS(x, y, ra, dec, order=order)

In [None]:
# class CubicWCS(galsim.wcs.CelestialWCS):
#     """WCS that includes a cubic radial term along with a 'gnomonic' projection to sky coordinates.

#     This wcs goes from x,y -> u,v -> ra, dec

#     The x,y to u,v transformation is determined by
#         u = scale*x*(1 + r3*(x^2 + y^2)),
#         v = scale*y*(1 + r3*(x^2 + y^2))

#     Where scale will be the pixel scale at the center of the image and r3 determines
#     the amount of radial distortion.  For a value of r3=3e-8 the pixel scale changes 
#     by ~10% at a radius of ~ 1400.

#     The inverse from u,v -> x,y is done via Cardano's method for solving cubic equations.

#     Note if an image with this wcs is written to a file, galsim cannot currently read the
#     file because all wcs classes are expected to be in the galsim namespace

#     Parameters:
#         scale:            The nominal pixel scale
#         r3:               The amount of radial distortion
#         origin:           Origin position for the image coordinate system.
#         world_origin:     Origin position in ra,dec
#         theta:            Rotation angle in radians
#     """

#     def __init__(self, scale, r3, origin, world_origin, theta=0.0):
#         self._scale = scale
#         self._r3 = r3
#         self._color = None
#         self._set_origin(origin)
#         self._world_origin = world_origin
#         self._theta = theta
#         sth, cth = np.sin(theta), np.cos(theta)
#         self._R = np.array([[cth, -sth], [sth, cth]])

#         self._q = 1./(self._r3*self._scale)
#         self._p = 1./self._r3

#         self._torad = galsim.arcsec / galsim.radians

    
#     def _ufunc(self, x, y):
#         x, y = self._R @ np.array([x, y])
#         return self._scale*x*(1. + self._r3 * (x**2 + y**2))

#     def _vfunc(self, x, y):
#         x, y = self._R @ np.array([x, y])
#         return self._scale*y*(1. + self._r3 * (x**2 + y**2))

# #     def _xfunc(self, u, v):
# #         wsq = u*u + v*v
# #         if wsq == 0.:
# #             return 0.
# #         else:
# #             w = np.sqrt(wsq)
# #             temp = (np.sqrt(self._q**2*wsq/4 + self._p**3/27))
# #             r = (temp + self._q*w/2)**(1./3) - (temp - self._q*w/2)**(1./3)
# #             x = u*r/w

# #     def _yfunc(self, u, v):
# #         wsq = u*u + v*v
# #         if wsq == 0.:
# #             return 0.
# #         else:
# #             w = np.sqrt(wsq)
# #             temp = (np.sqrt(self._q**2*wsq/4 + self._p**3/27))
# #             r = (temp + self._q*w/2)**(1./3) - (temp - self._q*w/2)**(1./3)
# #             y = v*r/w

#     def _radec_func(self, x, y):
#         return self._world_origin.deproject_rad(self._ufunc(x, y)*self._torad,
#                                                 self._vfunc(x, y)*self._torad)

#     def _xy_func(self, ra, dec):
#         u, v = self._world_origin.project_rad(ra, dec)

#         u /= self._torad
#         v /= self._torad
#         wsq = u*u + v*v
#         if wsq == 0.:
#             return 0.
#         else:
#             w = np.sqrt(wsq)
#             temp = (np.sqrt(self._q**2*wsq/4 + self._p**3/27))
#             r = (temp + self._q*w/2)**(1./3) - (temp - self._q*w/2)**(1./3)
#             x = u*r/w
#             y = v*r/w
#         return self._R.T @ np.array([x, y])

#     @property
#     def origin(self):
#         """The input radec_func
#         """
#         return self._origin

#     @property
#     def world_origin(self):
#         """The input radec_func
#         """
#         return self._world_origin

#     @property
#     def radec_func(self):
#         """The input radec_func
#         """
#         return self._radec_func

#     @property
#     def xy_func(self):
#         """The input xy_func
#         """
#         return self._xy_func

#     def _radec(self, x, y, color=None):
#         try:
#             return self._radec_func(x, y)
#         except Exception as e:
#             try:
#                 world = [self._radec(x1, y1) for (x1, y1) in zip(x, y)]
#             except Exception:  # pragma: no cover
#                 # Raise the original one if this fails, since it's probably more relevant.
#                 raise e
#             ra = np.array([w[0] for w in world])
#             dec = np.array([w[1] for w in world])
#             return ra, dec

#     def _xy(self, ra, dec, color=None):
#         try:
#             return self._xy_func(ra, dec)
#         except Exception as e:
#             try:
#                 image = [self._xy(ra1, dec1) for (ra1, dec1) in zip(ra, dec)]
#             except Exception:  # pragma: no cover
#                 # Raise the original one if this fails, since it's probably more relevant.
#                 raise e
#             x = np.array([w[0] for w in image])
#             y = np.array([w[1] for w in image])
#             return x, y

#     def _newOrigin(self, origin, world_origin):
#         return CubicWCS(self._scale, self._r3, origin, world_origin)

#     def _withOrigin(self, origin, world_origin, color):
#         return self._newOrigin(origin, world_origin)

#     def _writeHeader(self, header, bounds):
#         header["GS_WCS"] = ("CubicWCS", "GalSim WCS name")
#         header["GS_X0"] = (self.origin.x, "GalSim image origin x")
#         header["GS_Y0"] = (self.origin.y, "GalSim image origin y")
#         header["GS_RA0"] = (self.world_origin.ra.deg, "GalSim world origin x")
#         header["GS_DEC0"] = (self.world_origin.dec.deg,
#                              "GalSim world origin x")
#         header["GS_SCALE"] = (self._scale, "Nominal pixel scale")
#         header["GS_R3"] = (self._r3, "Cubic coefficient")

#         return self.affine(bounds.true_center)._writeLinearWCS(header, bounds)

#     @staticmethod
#     def _readHeader(header):
#         x0 = header["GS_X0"]
#         y0 = header["GS_Y0"]
#         ra0 = header["GS_RA0"]
#         dec0 = header["GS_DEC0"]
#         scale = header['GS_SCALE']
#         r3 = header['GS_R3']

#         return CubicWCS(scale, r3, galsim.PositionD(x0, y0),
#                         galsim.CelestialCoord(ra0*galsim.degrees, dec0*galsim.degrees))
#         return None

#     def copy(self):
#         ""
#         return CubicWCS(self._scale, self._r3, self.origin, self.world_origin)

#     def __eq__(self, other):
#         ""
#         return (self is other or
#                 (isinstance(other, CubicWCS) and
#                  self._scale == other._scale and
#                  self._xy_r3 == other._r3 and
#                  self.origin == other.origin and
#                  self.world_origin == other.world_origin
#                  ))

#     def __repr__(self):
#         return "galsim.CubicWCS(%r, %r, %r %r)" % (self._scale, self._r3, self.origin, self.world_origin)

#     def __hash__(self): return hash(repr(self))

#     def __getstate__(self):
#         d = self.__dict__.copy()
#         return d

#     def __setstate__(self, d):
#         self.__dict__ = d

In [None]:
# # Round trip test
# wcs = CubicWCS(0.2, 1e-7, galsim.PositionD(0,0), galsim.CelestialCoord(0*galsim.radians, 0*galsim.radians), theta=1)

# rng = np.random.default_rng(57721)
# x = rng.uniform(-1000, 1000, size=10_000)
# y = rng.uniform(-1000, 1000, size=10_000)

# ra, dec = wcs.xyToradec(x, y, units='rad')
# x1, y1 = wcs.radecToxy(ra, dec, units='rad')
# print(np.max(np.abs(x1 - x)))
# print(np.max(np.abs(y1 - y)))

In [None]:
# Photometry model built from LSST numbers.  They should be okay to scale
# against roughly, although LSST is probably a significantly darker site than
# many SSA sites.
# NOTE(review): the distortion tables describe a 0.4 m telescope, but
# area_ratio below rescales to a 1 m aperture -- confirm which is intended.

# Etendue / FoV.  I *think* this includes vignetting.
A = 319/9.6

# Zeropoints from David Kirkby's notes: photon rate per second of a 24 AB mag
# source in each band (callers scale by 10**(-0.4*(mag - 24))).
s0 = {band: A*rate
      for band, rate in zip('ugrizy', (0.732, 2.124, 1.681, 1.249, 0.862, 0.452))}

# Sky brightness in AB mag / arcsec^2,
# from http://www.lsst.org/files/docs/gee_137.28.pdf
B = dict(u=22.8, g=22.2, r=21.3, i=20.3, z=19.1, y=18.1)

# Collecting-area ratio of a 1 m aperture with 0.1 fractional (linear)
# obscuration to LSST's 8.36 m aperture with 0.61 obscuration.
area_ratio = (1.0**2 - 0.1**2) / (8.36**2 * (1.0 - 0.61**2))

# Sky photon rate per arcsec^2 per second, rescaled to the smaller aperture.
sbar = {band: area_ratio*s0[band] * 10**(-0.4*(mag - 24)) for band, mag in B.items()}

In [None]:
# Pointing for a first look at the distorted WCS: boresight at
# (RA, Dec) = (0, 0) deg with no field rotation.
boresight_ra = 0.0  # deg
boresight_dec = 0.0  # deg
boresight_rot = np.deg2rad(0)  # rad
world_origin = galsim.CelestialCoord(boresight_ra*galsim.degrees, boresight_dec*galsim.degrees)

# dthdr is tabulated in arcsec/micron; multiplying by the pixel size gives
# the arcsec/pixel plate scale polyWCS expects.
pixSize = 25  # micron/pixel
wcs = polyWCS(th, dthdr*pixSize, world_origin, theta=boresight_rot)

In [None]:
def _add_stamp(image, stamp):
    """Add the part of ``stamp`` that overlaps ``image`` into ``image`` in place."""
    overlap = stamp.bounds & image.bounds
    # A stamp entirely off the image yields undefined bounds; skip it.
    if overlap.isDefined():
        image[overlap] += stamp[overlap]


def _gaia_stars_in_bounds(wcs, bounds, gaia_dir, htm_depth=7, pad=0.1):
    """Read the Gaia reference stars that ``wcs`` maps inside ``bounds``.

    The search cone is centered on the image center.  Its radius is the
    largest center-to-corner separation plus ``pad`` degrees.  Every HTM
    shard file ``<gaia_dir>/<shard>.fits`` that intersects the cone is read
    and filtered to the rows whose pixel position lies inside ``bounds``.

    Returns:
        astropy Table of the selected rows, or None when nothing was read.
    """
    import os

    center = wcs.posToWorld(bounds.true_center)
    corners = [
        (bounds.xmin, bounds.ymin),
        (bounds.xmin, bounds.ymax),
        (bounds.xmax, bounds.ymin),
        (bounds.xmax, bounds.ymax),
    ]
    radius = max(
        center.distanceTo(wcs.posToWorld(galsim.PositionD(x, y))).deg
        for x, y in corners
    ) + pad
    shards = esutil.htm.HTM(htm_depth).intersect(center.ra.deg, center.dec.deg, radius)

    selected = []
    for shard in shards:
        data = Table.read(os.path.join(gaia_dir, f"{shard}.fits"))
        if len(data) == 0:
            continue
        # Vectorized replacement for a per-row posToImage + bounds.includes loop.
        x, y = wcs.radecToxy(np.asarray(data['coord_ra'], dtype=float),
                             np.asarray(data['coord_dec'], dtype=float), 'rad')
        inside = ((x >= bounds.xmin) & (x <= bounds.xmax)
                  & (y >= bounds.ymin) & (y <= bounds.ymax))
        selected.append(data[inside])
    if not selected:
        return None
    return vstack(selected)


def _draw_gaia_stars(image, wcs, table, psf, rng, exptime, fixAngles, time, observer):
    """Photon-shoot every star in ``table`` onto ``image`` in place."""
    # Wrap RA towards 0 so it is continuous around a boresight at RA = 0
    # (assumes coord_ra is stored in [0, 2pi) radians -- TODO confirm).
    # This modifies the table column in place.
    ra = table['coord_ra']
    ra[ra > np.pi] -= 2*np.pi
    dec = table['coord_dec']

    # NOTE: Gaia G-band flux is paired with the r-band zeropoint (rough approximation).
    mag = table['phot_g_mean_flux'].to(u.ABmag).value
    flux = area_ratio*s0['r'] * 10**(-0.4*(mag-24)) * exptime

    if fixAngles:
        from ssa.utils import catalog_to_apparent
        ra, dec = catalog_to_apparent(ra, dec, time, observer)

    x, y = wcs.radecToxy(ra, dec, 'rad')
    for x_, y_, flux_ in zip(tqdm.tqdm(x), y, flux):
        local_wcs = wcs.local(galsim.PositionD(x_, y_))
        stamp = (psf*flux_).drawImage(wcs=local_wcs, center=(x_, y_), method='phot', rng=rng)
        _add_stamp(image, stamp)


def _draw_streaks(image, wcs, endpoints, psf, rng, nPieces=10):
    """Photon-shoot great-circle streaks onto ``image`` in place.

    Each ``(ep1, ep2, flux)`` entry is split into ``nPieces`` equal
    great-circle segments so the streak can bend under WCS distortion.  Each
    segment is a thin box whose length equals the segment's angular length.
    It carries flux/nPieces and is rotated by the angle at its start point
    between the north celestial pole and its end point.  The box is then
    convolved with ``psf``.
    """
    ncp = galsim.CelestialCoord(0*galsim.degrees, 90*galsim.degrees)
    for ep1, ep2, streak_flux in endpoints:
        dist = ep1.distanceTo(ep2)
        piece_length = dist/nPieces/galsim.arcsec
        p1 = ep1
        for i in range(1, nPieces+1):
            p2 = ep1.greatCirclePoint(ep2, dist*i/nPieces)
            q = p1.angleBetween(ncp, p2)
            pos1 = wcs.toImage(p1)
            pos2 = wcs.toImage(p2)
            xmid = (pos1.x + pos2.x)/2
            ymid = (pos1.y + pos2.y)/2
            local_wcs = wcs.local(galsim.PositionD(xmid, ymid))
            box = galsim.Box(1e-12, piece_length, flux=streak_flux/nPieces).rotate(q)
            obj = galsim.Convolve(box, psf)
            stamp = obj.drawImage(wcs=local_wcs, center=(xmid, ymid), method='phot', rng=rng)
            _add_stamp(image, stamp)
            p1 = p2


def _apply_vignetting(image, wcs):
    """Multiply ``image`` in place by the tabulated unvignetted fraction ``unvig(th)``."""
    from scipy.interpolate import interp1d

    # Field angles beyond the table (th[-1] degrees) are treated as fully vignetted
    # instead of raising, so larger images do not crash.
    vigfun = interp1d(th, unvig, kind='cubic', bounds_error=False, fill_value=0.0)
    b = image.bounds
    xs, ys = np.meshgrid(np.arange(b.xmin, b.xmax+1), np.arange(b.ymin, b.ymax+1))
    # Radius from pixel (0, 0), where polyWCS puts the optical axis.  This should
    # really use the wcs for angular distances, but for now it is a shortcut
    # that scales by the central pixel scale (arcsec).
    rs = np.hypot(xs, ys)
    rs *= np.sqrt(wcs.pixelArea(galsim.PositionD(0, 0)))
    image.array[:] *= vigfun(rs/3600)


def makeImage(bounds, wcs, background, endpoints, time, observer, seed,
              fixAngles=False, exptime=1.0, gaia_dir=None):
    """
    Simulate an image with sky noise, Gaia stars, streaks and vignetting.

    Parameters:
        bounds:  galsim.BoundsI of the image
        wcs:  galsim celestial WCS mapping pixels <-> sky
        background:  sky surface brightness per arcsec^2 per second
                     (e.g. sbar['r']); scaled by exptime and pixel area
        endpoints:  list of (CelestialCoord, CelestialCoord, flux) streaks;
                    flux is the total for the streak (not scaled by exptime)
        time, observer:  only used when fixAngles is True, for
                         ssa.utils.catalog_to_apparent
        seed:  seed for galsim.BaseDeviate; drives both the sky noise and
               all photon shooting, so the image is reproducible
        fixAngles:  convert catalog Gaia positions to apparent positions
        exptime:  exposure time (same time unit as sbar / s0)
        gaia_dir:  directory of HTM-sharded Gaia FITS files; defaults to the
                   GAIA_DIR environment variable, else the original local path

    Returns:
        galsim.Image
    """
    import os

    if gaia_dir is None:
        gaia_dir = os.environ.get("GAIA_DIR", "/Users/meyers18/data/Gaia/")

    rng = galsim.BaseDeviate(seed)
    image = galsim.Image(bounds)

    wcs.makeSkyImage(image, background*exptime)
    image.addNoise(galsim.PoissonNoise(rng))

    psf = galsim.VonKarman(lam=500, r0=0.03, L0=25.0, force_stepk=0.01)  # FWHM ~ 3.0 arcsec

    table = _gaia_stars_in_bounds(wcs, image.bounds, gaia_dir)
    if table is not None and len(table) > 0:
        _draw_gaia_stars(image, wcs, table, psf, rng, exptime, fixAngles, time, observer)

    _draw_streaks(image, wcs, endpoints, psf, rng)

    # Naively apply vignetting (after noise, and to sky and sources alike).
    _apply_vignetting(image, wcs)

    return image

In [None]:
# bounds = galsim.BoundsI(-512, 511, -600, 599)
# Image footprint in pixels, roughly centered on pixel (0, 0), where polyWCS
# places the boresight.  Smaller/faster alternative:
#     bounds = galsim.BoundsI(-512, 511, -600, 599)
bounds = galsim.BoundsI(-1024, 1023, -1200, 1199)

# Pointing: boresight at (RA, Dec) = (0, 0) deg with the field rotated by 45 deg.
boresight_ra = 0.0  # deg
boresight_dec = 0.0  # deg
boresight_rot = np.deg2rad(45)  # rad
world_origin = galsim.CelestialCoord(boresight_ra*galsim.degrees, boresight_dec*galsim.degrees)

# dthdr is tabulated in arcsec/micron; multiply by the pixel size for arcsec/pixel.
pixSize = 25  # micron/pixel
wcs = polyWCS(th, dthdr*pixSize, world_origin, theta=boresight_rot)

# r-band sky surface brightness per arcsec^2 per second.
background = sbar['r']

# Streaks as (start, end, total flux); each is drawn along the great circle.
endpoints = [
    (galsim.CelestialCoord(-0.2*galsim.degrees, 0.2*galsim.degrees),
     galsim.CelestialCoord(0.2*galsim.degrees, 0.1*galsim.degrees),
     5e5)
]

# Epoch and site are only used for the aberration correction (fixAngles=True).
# A fixed epoch keeps Restart & Run All reproducible alongside the fixed seed;
# substitute Time.now() to simulate the current epoch.
time = Time("2024-01-01T00:00:00", scale="utc")
# NOTE(review): arguments presumably (longitude deg, latitude deg, height m) -- confirm against ssa.EarthObserver.
observer = ssa.EarthObserver(170.0, -30.0, 2400.0)

In [None]:
image = makeImage(bounds, wcs, background, endpoints, time, observer, 123)

# origin='lower' + a pixel extent so y increases upwards and the axes show
# galsim pixel coordinates (imshow defaults to origin='upper', flipping y).
fig, ax = plt.subplots(figsize=(18, 18))
im = ax.imshow(image.array, vmin=0, vmax=300, origin='lower',
               extent=(image.bounds.xmin - 0.5, image.bounds.xmax + 0.5,
                       image.bounds.ymin - 0.5, image.bounds.ymax + 0.5))
fig.colorbar(im, ax=ax, label='flux per pixel')
ax.set(title='Simulated image (catalog Gaia positions)', xlabel='x [pixel]', ylabel='y [pixel]')
plt.show()

In [None]:
# Very slightly different with aberration correction to the Gaia catalog
image = makeImage(bounds, wcs, background, endpoints, time, observer, 123, fixAngles=True)

# origin='lower' + a pixel extent so y increases upwards and the axes show
# galsim pixel coordinates (imshow defaults to origin='upper', flipping y).
fig, ax = plt.subplots(figsize=(18, 18))
im = ax.imshow(image.array, vmin=0, vmax=300, origin='lower',
               extent=(image.bounds.xmin - 0.5, image.bounds.xmax + 0.5,
                       image.bounds.ymin - 0.5, image.bounds.ymax + 0.5))
fig.colorbar(im, ax=ax, label='flux per pixel')
ax.set(title='Simulated image (aberration-corrected Gaia positions)',
       xlabel='x [pixel]', ylabel='y [pixel]')
plt.show()