In [None]:
from pathlib import Path
from matplotlib import pyplot as plt
import numpy as np
from hpmoc import PartialUniqSkymap
from hpmoc.healpy import healpy as hp
from hpmoc.plot import plot, gridplot
import hpmoc.utils as ut
from astropy.table import Table
from astropy.wcs import WCS
from astroquery.skyview import SkyView
from astropy.coordinates.sky_coordinate import SkyCoord
from astropy.units import Unit, deg, rad
from scipy.interpolate import interp2d
from nptyping import NDArray
from typing import Any, Tuple, Callable, Union, Optional

# Test fixtures live in ``tests/data`` one directory above the notebook's
# working directory.
DATA = Path.cwd().parent / "tests" / "data"

# Skymap file loaded with hpmoc's 'ligo' reading strategy; ``m`` is the
# multi-order skymap used by the plotting cells below.
NUNIQ_FITS = DATA / 'S200105ae.fits'
m = PartialUniqSkymap.read(NUNIQ_FITS, strategy='ligo')

In [None]:
# Download the SkyView 'BAT SNR 150-195' survey image around M1 (requires
# network access) and keep the first HDU of the first returned HDU list.
hdu = SkyView.get_images(position='M1', survey='BAT SNR 150-195')[0][0]

In [None]:
# Convert the image into a PartialUniqSkymap from its pixel data and WCS.
mh = PartialUniqSkymap(hdu.data, WCS(hdu.header))

In [None]:
# Plot the converted skymap back onto the image's own WCS; missing pixels are
# drawn blue and NaN pixels green so coverage gaps are easy to spot.
mh.plot(fig={'dpi': 200}, missing_color='blue', nan_color='green',
       width=1440, height=1440, rot=(80, 20), projection=WCS(hdu.header))

In [None]:
# Draw the skymap ``m`` in the image's WCS, with the contour requested by
# ``cr=[0.9]`` (presumably the 90% credible region) in blue.
w = WCS(hdu.header)
m.plot(projection=w, cr=[0.9], cr_kwargs={'colors': 'blue'})

In [None]:
# The raw image in the same WCS, for visual comparison.
axh = plt.subplot(1, 1, 1, projection=w)
axh.imshow(hdu.data, cmap='gist_heat_r')

In [None]:
m

In [None]:
m.plot()

In [None]:
# TAN (gnomonic) projection of ``m``; ``ax`` is reused below as a source of a
# WCS and rendered pixel values.
ax = m.plot(rot=(70, 5), vdelta=0.3, hdelta=0.3, projection='TAN')

In [None]:
from hpmoc.utils import wcs2ang, resol2nside, wcs2resol, nest2uniq, wcs2nest

# Reuse the WCS and the rendered pixel values of the TAN plot above as an
# input image for the interpolation helpers defined below.
wcs = ax.wcs
# ``.T`` transposes the rendered image; presumably so it matches the
# ``data[x, y]`` indexing used by ``interp_wcs_nn`` -- TODO confirm.
data = ax.images[0].get_array().data.T

# nearest-neighbor
def interp_wcs_nn(
        wcs: 'astropy.wcs.WCS',
        data: NDArray[(Any,), Any],
) -> Tuple[NDArray[(Any,), int], NDArray[(Any,), float]]:
    """
    Resample an image on a FITS WCS grid onto HEALPix pixels by
    nearest-neighbor lookup.

    The WCS footprint is covered with NESTED HEALPix pixels two orders finer
    than the image resolution (``wcs2nest(..., order_delta=2)``). Each HEALPix
    pixel takes the value of the image pixel whose integer index is closest to
    that pixel's fractional ``(x, y)`` coordinate.

    Parameters
    ----------
    wcs : astropy.wcs.WCS
        World coordinate system giving the sky position of every image
        pixel. For an HDU named ``hdu`` this is
        ``astropy.wcs.WCS(hdu.header)``. *Dimensionful quantities must have
        their units attached manually.*
    data : array-like
        Pixel values laid out on ``wcs`` (e.g. ``hdu.data``).

    Returns
    -------
    u : array
        NUNIQ HEALPix indices of the resampled skymap.
    s : array-like
        Image values sampled at the pixels in ``u``.

    See Also
    --------
    hpmoc.partial.PartialUniqSkymap
    astropy.wcs.WCS
    """
    nside, nest, x, y = wcs2nest(wcs, order_delta=2)

    # Snap each fractional pixel coordinate to its nearest integer index.
    # NOTE(review): ``data`` is indexed as ``data[x, y]``. An array taken
    # straight from a FITS HDU is usually row-major ``(y, x)``, so callers may
    # need to pass ``data.T``; rounding can also reach ``shape`` at the far
    # edge. Confirm both against ``wcs2nest``'s coordinate convention.
    ix = np.round(x).astype(int)
    iy = np.round(y).astype(int)
    sampled = data[ix, iy]

    uniq = nest2uniq(nest, nside)
    return uniq, sampled


def interp_wcs(
        wcs: 'astropy.wcs.WCS',
        data: NDArray[(Any,), Any],
        interp: Optional[
            Union[
                str,
                Tuple[
                    int,
                    Callable[
                        [
                            NDArray[(Any,), float],
                            NDArray[(Any,), float],
                            NDArray[(Any,), Any]
                        ],
                    NDArray[(Any,), Any]
                    ]
                ],
            ]
        ] = 'nearest'
) -> Tuple[NDArray[(Any,), int], NDArray[(Any,), float]]:
    """
    Interpolate ``data`` with coordinates specified by ``wcs`` FITS
    world coordinate system into a HEALPix NUNIQ skymap.

    Parameters
    ----------
    wcs : astropy.wcs.WCS
        The world coordinate system defining pixel locations. If loading
        a FITS file as an HDU called ``hdu``, you can get this argument
        as ``astropy.wcs.WCS(hdu.header)``. *Note that you will need to
        manually include units for dimensionful quantities.*
    data : array-like
        The data corresponding to ``WCS``. Available from an HDU as
        ``hdu.data``.
    interp : str or (int, func) or None, optional
        The interpolation strategy to use. Can be ``None`` (same as
        ``"nearest"``) or a string specifying one of the following
        pre-defined strategies:

        - "nearest" for nearest-neighbor
        - "bilinear" for bilinear (not implemented yet)

        or else a tuple whose first element is the number of orders by
        which the pixels covering the ``WCS`` should have their resolution
        increased ("nearest" uses a value of 2, "bilinear" a value of 1;
        heuristically, a more sophisticated interpolation scheme can probably
        get away with 1), while the second element is a function taking the
        x, y coordinates of the pixels followed by the pixel values in ``data``
        and returning the interpolated pixel values (which will form the return
        value ``s`` of this function).

    Returns
    -------
    u : array
        The corresponding NUNIQ HEALPix indices of the input skymap.
    s : array-like
        The pixel-values of the input skymap interpolated at the locations of
        the pixels in ``u``.

    Raises
    ------
    NotImplementedError
        If ``interp`` is ``"bilinear"``.
    ValueError
        If ``interp`` is any other unrecognized string, or a tuple that does
        not have exactly two elements.

    See Also
    --------
    hpmoc.partial.PartialUniqSkymap
    astropy.wcs.WCS
    """
    # ``None`` is allowed by the ``Optional`` annotation; treat it as the
    # default strategy instead of failing on ``None[0]``.
    if interp is None or interp == 'nearest':
        return interp_wcs_nn(wcs, data)
    if interp == 'bilinear':
        raise NotImplementedError(
            "bilinear interpolation is not implemented yet")
    if isinstance(interp, str):
        raise ValueError(f"Unrecognized interpolation strategy: {interp}")
    # A custom strategy is an ``(order_delta, func)`` pair. Call ``func``,
    # not the tuple itself (calling the tuple raised TypeError).
    order_delta, interp_func = interp
    nside, nest, x, y = wcs2nest(wcs, order_delta=order_delta)
    return nest2uniq(nest, nside), interp_func(x, y, data)

In [None]:
# Same conversion as the earlier ``mh`` cell, displayed rather than stored.
PartialUniqSkymap(hdu.data, WCS(hdu.header))

In [None]:
# Generic entry point; the default strategy is 'nearest'.
interp_wcs(w, hdu.data)

In [None]:
# Keep the NUNIQ indices ``u`` and sampled values ``s`` for inspection.
u, s = interp_wcs_nn(w, hdu.data)

In [None]:
# Number of image pixels, to compare against the number of HEALPix pixels.
hdu.data.size

In [None]:
WCS(hdu.header)

In [None]:
# ``x``, ``y``, ``nest`` and ``nside`` are not assigned at module level by any
# cell above (only inside functions), so on a fresh kernel this cell raised
# NameError. Recompute them from the BAT image's WCS, using the same
# ``order_delta`` as ``interp_wcs_nn``, and keep them for later inspection.
nside, nest, x, y = wcs2nest(w, order_delta=2)
mh = PartialUniqSkymap(hdu.data[np.round(x).astype(int), np.round(y).astype(int)],
                       nest2uniq(nest, nside))

In [None]:
# Side by side, both in the image's WCS: the raw image on the left and its
# HEALPix resampling ``mh`` on the right.
axw = plt.subplot(1, 2, 1, projection=w)
axw.imshow(hdu.data, cmap='gist_heat_r')
axh = mh.plot(fig=axw.figure, subplot=(1, 2, 2), projection=w)
axh.grid(False)

In [None]:
# Rendered pixel values of the TAN plot (``data`` was transposed when it was
# extracted, so its axes appear swapped relative to that plot).
plt.imshow(data)

In [None]:
# ``interp``, ``nest`` and ``nside`` are not defined at module level by the
# cells above (``interp`` only exists as a parameter of ``interp_wcs``), so
# this cell raised NameError on a fresh kernel. Rebuild the nearest-neighbor
# resampling of the TAN plot image (``wcs``/``data`` extracted from ``ax``)
# explicitly instead.
uniq_ax, vals_ax = interp_wcs_nn(wcs, data)
ax2 = PartialUniqSkymap(vals_ax, uniq_ax).plot(nan_color='#0007',
                                               missing_color='#7777',
                                               cr=[0.9])
ax2.set_facecolor('blue')

In [None]:
im = ax.images[0]

In [None]:
im.get_array().data.shape

In [None]:
# NOTE(review): leftover IPython help lookup. ``interp2d`` is not used by the
# cells shown, and it was deprecated in SciPy 1.10 and removed later --
# confirm against the pinned SciPy version, then remove this cell.
interp2d?

In [None]:
# NOTE(review): missing ``()`` -- this displays the bound method, not the
# interpolation setting; ``get_interpolation()`` was probably intended.
ax.images[0].get_interpolation

In [None]:
# NOTE(review): relies on a module-level ``nest`` set by an earlier cell;
# verify that holds on a fresh kernel.
len(nest)

In [None]:
# One order less refinement than the ``order_delta=2`` used by
# ``interp_wcs_nn``.
wcs2nest(wcs, order_delta=1)

In [None]:
# NOTE(review): relies on a module-level ``x`` set by an earlier cell.
len(x)

In [None]:
# Expand each NESTED index into its four children one order finer
# (NEST parent p -> 4*p + {0, 1, 2, 3}).
(4 * nest.reshape((-1, 1)) + np.arange(4)).ravel()

In [None]:
# NOTE(review): ``include`` is not defined by any cell shown; this raises
# NameError on a fresh kernel.
sum(include)

In [None]:
len(nest)

In [None]:
# NOTE(review): incomplete attribute lookup (likely a tab-completion
# leftover); complete it or remove the cell.
hp.n

In [None]:
# NOTE(review): astropy's WCS has no ``world_toP`` attribute (probably a
# truncated ``world_to_pixel``); as written this raises AttributeError.
ax.wcs.world_toP

In [None]:
# NSIDE corresponding to the plot's pixel resolution (presumably), doubled by
# ``<< 1``, i.e. one HEALPix order finer.
ut.resol2nside(ut.wcs2resol(ax.wcs).to('rad').value, degrees=False) << 1

In [None]:
ut.wcs2mask_and_uniq(ax.wcs)

In [None]:
# Ratio of the first to the second pixel-grid dimension (astropy's
# ``pixel_shape`` lists the x axis first).
ax.wcs.pixel_shape[0]/ax.wcs.pixel_shape[1]

In [None]:
fig = plt.gcf()

In [None]:
fig.get_size_inches()

In [None]:
# NOTE(review): opens a new empty figure and rebinds ``fig``; looks like a
# leftover scratch cell.
fig = plt.figure()
