In [None]:
from pathlib import Path
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
import numpy as np
from hpmoc import PartialUniqSkymap
from hpmoc.healpy import healpy as hp
from hpmoc.plot import *
import hpmoc.plot as hpplot
import hpmoc.utils as ut
from hpmoc.points import PointsTuple, Rgba
from astropy.units import Unit, Quantity
from astropy.table import Table
from astropy.wcs import WCS
from astroquery.skyview import SkyView

# Test-data directory of the repository: expected at ../tests/data relative to
# the working directory the notebook is started from.
DATA = Path.cwd().parent.joinpath("tests", "data")

# Example skymaps in the different on-disk layouts exercised below.
NUNIQ_FITS = DATA.joinpath('S200105ae.fits')
NUNIQ_FITS_GZ = DATA.joinpath('S200105ae.fits.gz')
BAYESTAR_NEST_FITS_GZ_1024 = DATA.joinpath('S200316bj-1-Preliminary.fits.gz')
CWB_NEST_FITS_GZ_128 = DATA.joinpath('S200114f-3-Initial.fits.gz')
CWB_RING_FITS_GZ_128 = DATA.joinpath('S200129m-3-Initial.fits.gz')

# The skymap used by most of the cells below.
m = PartialUniqSkymap.read(NUNIQ_FITS, strategy='ligo')

In [None]:
# Show the documentation of get_wcs. Plain-Python equivalent of `get_wcs?`:
# the output is stored in the notebook and the cell also runs outside IPython.
help(get_wcs)

In [None]:
# Extra scatter points to overlay on the skymaps. Each entry is presumably
# (RA [deg], Dec [deg], size [deg][, label]) -- TODO confirm against
# hpmoc.points.PointsTuple.
pts2 = PointsTuple(
    points=[
        (179., 16., 2.),
        (62, -14, 3., 'Candidate'),
        (
            # M1 (Crab Nebula) at RA 5h34.5m, Dec +22°01'. The RA "minutes"
            # are minutes of *time* (1/60 hourangle = 0.25 deg), not
            # arcminutes, so they must be expressed in hourangle units.
            ((5 + 34.5/60)*Unit('hourangle')).to('deg').value,
            (22*Unit('deg')+1*Unit('arcmin')).to('deg').value,
            (6*Unit('arcmin')).to('deg').value,
            'M1'
        )
    ],
    rgba=Rgba(0.6, 0.2, 1, 0.5),
    marker='2',
    label='More points',
)

In [None]:
# Gnomonic (TAN) view of `m` with the extra points overlaid; `rot` presumably
# sets the view centre and hdelta/vdelta the angular extent -- see hpmoc plot.
ax = m.plot(pts2, projection='TAN', rot=(70, 5), hdelta=0.3, vdelta=0.3)

In [None]:
def gridplot(
        *skymaps: Union[
            'hpmoc.PartialUniqSkymap',
            NDArray[Any, Any],
            Tuple[
                NDArray[Any, Any],
                Optional[Union[NDArray[Any, Int], 'astropy.wcs.WCS']],
            ],
        ],
        fig: Optional[
            Union[
                'matplotlib.figure.Figure',
                'matplotlib.gridspec.GridSpec',
                dict,
            ]
        ] = None,
        projections: List[
            Union[
                str,
                'astropy.wcs.WCS',
                'astropy.io.fits.Header',
                List[
                    Union[
                        str,
                        'astropy.wcs.WCS',
                        'astropy.io.fits.Header',
                        'astropy.visualization.wcsaxes.WCSAxes',
                    ]
                ],
            ]
        ]= ('MOL',),
        scatters: Optional[List[Optional[List[PointsTuple]]]] = None,
        # fig args
        subplot_height: float = DEFAULT_GRID_ROW_HEIGHT,
        # fig.add_gridspec args
        ncols: int = DEFAULT_NCOLS,
        hspace: float = DEFAULT_HSPACE,
        wspace: float = DEFAULT_WSPACE,
        wshrink: Union[float, List[float]] = 1.,
        # plotter args
        subplot_kwargs: Optional[List[Optional[List[dict]]]] = None,
        **kwargs
) -> Tuple[
        'matplotlib.gridspec.GridSpec',
        List[List['astropy.visualization.wcsaxes.WCSAxes']]
]:
    """
    Make a grid plot of multiple skymaps (optionally with scatterplots for
    each).

    Parameters
    ----------
    *skymaps : 'hpmoc.PartialUniqSkymap', array, or (array, array)
        The skymaps to plot. Can be a ``PartialUniqSkymap``, a
        single-resolution HEALPix skymap *in NEST ordering only*, or a tuple of
        (pixel values, NUNIQ indices) accepted as the first two arguments to
        ``PartialUniqSkymap``.
    fig : matplotlib.figure.Figure, matplotlib.gridspec.GridSpec, or dict, optional
        The figure to plot to. If not provided, a new figure will be created.
        If a dictionary is provided, it will be passed as keyword arguments to
        create a new figure. If a ``GridSpec`` is provided, then the figure
        to which it is attached will be used, and that ``GridSpec`` will be
        used to define the layout (``ncols`` is then derived from the
        ``GridSpec``'s number of columns).
    projections : List[Union[str, WCS, fits.Header, List[Union[str, WCS, fits.Header, WCSAxes]]]], optional
        A list of projections (see the ``projection`` argument of ``plot``) to
        use for each skymap in ``skymaps``. If multiple projections are
        specified, they will be plotted alongside each other; if this makes the
        figure too wide, change the number of columns in the grid with
        ``ncols``. Instead of a single projection, any element may be a list
        with one entry per skymap; entries that are existing ``WCSAxes`` are
        plotted onto directly (their ``wcs`` and ``frame_class`` are reused),
        so the ``axs`` returned by a previous call can be passed here to
        overlay more skymaps on the same axes. This argument is never
        modified.
    scatters : List[Optional[List[PointsTuple]]], optional
        Scatterplots to use, one list for each skymap containing the sets of
        points to plot for that skymap. For any of the skymaps which have a
        ``point_sources`` attribute (e.g. ``PartialUniqSkymap`` instances),
        you can default to plotting those point sources by passing ``None``
        instead of a list of point sources; this is also the default behavior
        if ``scatters=None``. Skymaps without ``point_sources`` get no points.
    subplot_height : float, optional
        The height of each subplot (in inches). Width is automatically
        determined. Overall figure height will be this times the number of
        grid rows.
        *Ignored if* ``fig`` *is a pre-existing figure, or if* ``figsize`` *is
        specified as a keyword argument in* ``fig`` (though this is not
        recommended usage and should be avoided unless you know what you're
        doing).
    ncols : int, optional
        How many columns of *skymaps* to include in the grid. **NB: All
        subplots for a single skymap are counted as a single column** (in
        contrast to ``matplotlib.gridspec.GridSpec``, which counts each subplot
        in a single column). The number of rows is determined automatically
        from the number of skymaps provided combined with ``ncols``.
    hspace : float, optional
        How much vertical space (height) to reserve between subplots.
    wspace : float, optional
        How much horizontal space (width) to reserve between subplots.
    wshrink : float or list of floats, optional
        Scale the plot widths by this much. Useful if you are adding color
        bars to preserve spacing. If a list, each element corresponds to a
        projection.
    subplot_kwargs : List[Optional[List[dict]]], optional
        Lists of keyword argument dictionaries that will be used for each
        subplot. The first index specifies the projection, and the second
        index specifies the skymap from ``skymaps``; a ``None`` entry in the
        outer list means "no extra keyword arguments" for that projection.
        This lets you easily specify keyword arguments for specific
        projections (since often only one of the ``projections`` requires
        skymap-specific parameters). **NB: These subplot-specific keyword
        arguments take precedence over** ``**kwargs`` **for their respective
        subplots. Projection-related keyword arguments are ignored if**
        ``projections`` **is a list of lists of axes; see the** ``plot``
        **documentation for further details.** This argument is never
        modified.
    **kwargs
        Keyword arguments applied to all projections; **again, see** ``plot``
        **for usages and caveats.**

    Returns
    -------
    gs : matplotlib.gridspec.GridSpec
        The ``GridSpec`` the subplots were placed on (``gs.figure`` is the
        figure). It can be passed back in as ``fig`` to plot onto the same
        grid.
    axs : List[List[WCSAxes]]
        The axes that were plotted to, indexed as
        ``axs[projection_index][skymap_index]``.

    Raises
    ------
    ValueError
        If ``scatters`` is provided and is ill-formatted or not of the same
        length as ``skymaps``; if ``wshrink`` is a list whose length differs
        from that of ``projections``; if ``subplot_kwargs`` is included and is
        not of the same length as ``projections``, or if its elements are
        neither ``None`` nor lists of the same length as ``skymaps``; if an
        element of ``projections`` given as a list does not have the same
        length as ``skymaps``; if ``fig`` is passed as a ``GridSpec`` which is
        not compatible with the rest of the arguments given; or if one of the
        ``projections`` is specified as a string but cannot be found in this
        module.

    See Also
    --------
    plot
    """

    from math import ceil
    from matplotlib.pyplot import figure
    from matplotlib.figure import Figure
    from matplotlib.gridspec import GridSpec
    from astropy.io.fits import Header
    from astropy.wcs import WCS

    # Projection specs that still need to be turned into a WCS/frame class (as
    # opposed to existing axes, whose WCS/frame class are reused).
    proj_spec_types = (str, WCS, Header)

    gs = fig if isinstance(fig, GridSpec) else None
    fig = fig if gs is None else gs.figure
    nᵖ = len(projections)  # number of projections per skymap
    nᶜ = len(skymaps)      # number of cells, i.e. skymaps
    nˢ = nᶜ*nᵖ             # number of true subplots in fig
    if gs is None:
        nʳ = nᵖ*ncols      # number of true subplots per row
        nrows = int(ceil(nᶜ/ncols))
    else:
        nʳ = gs.ncols
        if nʳ % nᵖ != 0:
            raise ValueError(f"Cannot evenly fit {nᵖ} projections in "
                             f"{nʳ} columns.")
        ncols = nʳ // nᵖ
        nrows = gs.nrows
        if nrows * nʳ < nˢ:
            raise ValueError(f"Not enough rows {nrows} for {nˢ} subplots.")

    # A scalar wshrink (no len()) applies the same factor to every projection.
    try:
        n_shrink = len(wshrink)
    except TypeError:
        wshrink = [wshrink]*nᵖ
    else:
        if n_shrink != nᵖ:
            raise ValueError(f"Must provide one wshrink {wshrink} for "
                             f"each projection {projections}")

    scatters = scatters or [None]*nᶜ
    if len(scatters) != nᶜ:
        raise ValueError(f"scatters {scatters} must have same len as skymaps "
                         f"{skymaps}")
    # ``None`` means "plot this skymap's own point sources", if it has any;
    # build a new list so the caller's ``scatters`` is left untouched.
    scatters = [
        list(getattr(s, 'point_sources', None) or []) if sc is None else sc
        for s, sc in zip(skymaps, scatters)
    ]

    subplot_kwargs = subplot_kwargs or [None]*nᵖ
    if len(subplot_kwargs) != nᵖ:
        raise ValueError(f"subplot_kwargs {subplot_kwargs} must have same "
                         f"len as projections {projections}")
    # Layer each subplot's own kwargs over the shared ``kwargs``, building new
    # lists and dicts so neither the caller's lists nor dicts are modified.
    merged_kwargs = []
    for i, kw in enumerate(subplot_kwargs):
        if kw is None:
            merged_kwargs.append([kwargs.copy() for _ in range(nᶜ)])
        elif len(kw) != nᶜ:
            raise ValueError(f"{i}-th element of subplot_kwargs {kw} must "
                             f"have same len as skymaps {skymaps} or else be "
                             "omitted.")
        else:
            merged_row = []
            for kkw in kw:
                merged = kwargs.copy()
                merged.update(kkw or {})
                merged_row.append(merged)
            merged_kwargs.append(merged_row)
    subplot_kwargs = merged_kwargs

    # Resolve every (projection, skymap) pair into a WCS and a frame class.
    # Results go into new lists; ``projections`` itself is never assigned to
    # (it may be the default tuple or a list the caller still uses).
    wcs_by_proj = []  # wcs_by_proj[iᵖ][iᶜ]
    frames = []       # frames[iᵖ][iᶜ]
    for iᵖ, p in enumerate(projections):
        if isinstance(p, proj_spec_types):
            p = [p]*nᶜ
        elif len(p) != nᶜ:
            raise ValueError(f"{iᵖ}-th element of projections {p} must be a "
                             f"single projection or have same len as skymaps "
                             f"{skymaps}")
        fa = []
        pa = []
        for iᶜ, pp in enumerate(p):
            kw = subplot_kwargs[iᵖ][iᶜ]
            if isinstance(pp, proj_spec_types):
                fa.append(get_frame_class(
                    pp,
                    **{k: v for k, v in kw.items()
                       if k in GET_FRAME_CLASS_KWARG_KEYS}
                ))
                pa.append(get_projection(
                    pp,
                    **{k: v for k, v in kw.items()
                       if k in GET_WCS_KWARG_KEYS}
                ))
            else:
                # Existing axes (e.g. from a previous call): reuse them.
                fa.append(pp.frame_class)
                pa.append(pp.wcs)
        wcs_by_proj.append(pa)
        frames.append(fa)

    # from https://matplotlib.org/tutorials/intermediate/gridspec.html
    # Each column's width ratio is the pixel aspect ratio (width / height) of
    # its projection (taken from the first skymap's WCS), scaled by wshrink.
    widths = [w * pa[0].pixel_shape[0] / pa[0].pixel_shape[1]
              for w, pa in zip(wshrink, wcs_by_proj)]
    width_ratios = widths*ncols
    nat_width, nat_height = (subplot_height*sum(width_ratios),
                             subplot_height*nrows)
    if not isinstance(fig, Figure):
        fkw = {
            'figsize': (nat_width, nat_height),
            'facecolor': 'w',
            'edgecolor': 'k',
        }
        fkw.update(fig or {})
        fig = figure(**fkw)
    if gs is None:
        gs = fig.add_gridspec(
            nrows=nrows,
            ncols=nᵖ*ncols,
            width_ratios=width_ratios,
            hspace=hspace,
            wspace=wspace,
            left=0,
            right=1,
            bottom=0,
            top=1,
        )

    axs = []
    for iᵖ, (p, f, k) in enumerate(zip(wcs_by_proj, frames, subplot_kwargs)):
        axr = []
        for iᶜ, s in enumerate(skymaps):
            # All projections of one skymap sit next to each other in a row.
            row, col = divmod(iᶜ*nᵖ+iᵖ, nʳ)
            ax = plot(s, *scatters[iᶜ], projection=p[iᶜ], fig=fig,
                      subplot=gs[row, col], frame_class=f[iᶜ], **k[iᶜ])
            # Only show an axis label where that axis also shows tick labels,
            # to avoid repeating labels across the grid.
            co_ra, co_dec = ax.coords
            co_ra.set_axislabel_visibility_rule('labels')
            co_dec.set_axislabel_visibility_rule('labels')
            axr.append(ax)
        axs.append(axr)

    return gs, axs

In [None]:
# Plot the example skymaps on a grid with two projections each (CEA plus a
# zoomed TAN view), then overlay each map's HEALPix orders on the same axes.
map_files = (NUNIQ_FITS, BAYESTAR_NEST_FITS_GZ_1024, CWB_NEST_FITS_GZ_128,
             CWB_RING_FITS_GZ_128)
# `m` was already read from NUNIQ_FITS in the first cell, so reuse it.
maps = [m, *[PartialUniqSkymap.read(f, strategy='ligo')
             for f in map_files[1:]]]
# Panel titles: file names with all extensions ('.fits', '.fits.gz') stripped.
titles = [f.name.split('.')[0] for f in map_files]
gs, axs = gridplot(*maps, wspace=0.2, hspace=0.05, wshrink=0.7,
                   scatters=[[pts2]]*len(maps), ncols=1,
                   projections=['CEA', 'TAN'], height=360,
                   subplot_kwargs=[[{'width': 720, 'vmax': s.max().value}
                                    for s in maps],
                                   [{'width': 360, 'vdelta': 0.2,
                                     'hdelta': 0.2, 'cbar': True,
                                     'rot': (70, 5), 'vmax': s.max().value}
                                    for s in maps]],
                   cr=[0.5, 0.9, 0.99], vmin=0., cmap='bone_r',
                   cr_kwargs={'colors': ['aquamarine', 'teal', 'blue']})
# `axs` is indexed [projection][skymap]; zipping its rows yields all panels of
# one skymap at a time, which all get that skymap's title.
for *axx, title in zip(*axs, titles):
    for ax in axx:
        ax.set_title(title)
orders = [s.orders(as_skymap=True) for s in maps]
# Overlay the orders onto the existing axes. Pass fresh copies of the axes
# lists so that `axs` itself can never be modified by gridplot.
gridplot(*orders, projections=[list(row) for row in axs], fig=gs, alpha=0.3,
         cmap='rainbow', vmin=0,
         subplot_kwargs=[[{'cbar': {'shrink': 0.5}, 'vmax': o.max()}
                          for o in orders],
                         [{'cmap': None} for o in orders]]);

In [None]:
# Documentation of plt.colorbar. Plain-Python equivalent of `plt.colorbar?`:
# the output is stored in the notebook and the cell also runs outside IPython.
help(plt.colorbar)

In [None]:
# Documentation of GridSpec. Plain-Python equivalent of `GridSpec?`: the
# output is stored in the notebook and the cell also runs outside IPython.
help(GridSpec)

In [None]:
# `axs` from gridplot is a list (one per projection) of lists of axes, so
# `axs[0]` is a list and has no `.coords`. Take an actual axes instead: the
# first one added to the grid's figure, presumably the first skymap's
# first-projection panel.
a = gs.figure.axes[0]

In [None]:
# Unpack the two coordinate helpers of `a` (named RA/Dec, as in gridplot).
co_ra, co_dec = a.coords

In [None]:
# Experiment: hide the tick labels of the first (RA) coordinate on `a`.
co_ra.set_ticklabel_visible(False)

In [None]:
# Documentation of the axis-label visibility rule used by gridplot.
# Plain-Python equivalent of the IPython `?` lookup, so the output is stored
# in the notebook and the cell also runs outside IPython.
help(co_ra.set_axislabel_visibility_rule)

In [None]:
# Size (width, height) in inches of the figure the gridplot GridSpec is on.
gs.figure.get_size_inches()

In [None]:
# Hand-built two-panel layout (2:1 width ratio) with a CEA and a TAN panel,
# for comparison with what gridplot produces.
fig = plt.figure()
gs1 = fig.add_gridspec(1, 2, width_ratios=[2, 1])
cea_wcs = get_wcs('CEA')
ax1 = fig.add_subplot(gs1[0, 0], projection=cea_wcs)
m.plot(ax=ax1)
tan_wcs = get_wcs('TAN')
ax2 = fig.add_subplot(gs1[0, 1], projection=tan_wcs)
m.plot(ax=ax2)

In [None]:
# Inspect the first cell of the GridSpec returned by gridplot.
gs[0]

In [None]:
# The figure the gridplot GridSpec is attached to.
gs.figure

In [None]:
# Which class indexing a GridSpec returns.
type(gs[0])

In [None]:
# Presumably the HEALPix NSIDE matching the angular pixel size of `ax`'s WCS
# (`ax` is whichever axes was bound last at module level). The resolution is
# handed over in radians, hence degrees=False.
resol_rad = ut.wcs2resol(ax.wcs).to('rad').value
ut.resol2nside(resol_rad, degrees=False)

In [None]:
# Scratch: inspect ut.wcs2mask_and_uniq for `ax`'s WCS (presumably a pixel
# mask plus the matching NUNIQ indices -- see hpmoc.utils to confirm).
ut.wcs2mask_and_uniq(ax.wcs)

In [None]:
# Pixel aspect ratio (width / height) of `ax`'s WCS -- the same ratio gridplot
# uses to size its columns.
nx_pix, ny_pix = ax.wcs.pixel_shape[0], ax.wcs.pixel_shape[1]
nx_pix / ny_pix

In [None]:
# Grab pyplot's current figure for inspection.
fig = plt.gcf()

In [None]:
# Its size (width, height) in inches.
fig.get_size_inches()

In [None]:
# Scratch: start a fresh, empty figure (it becomes pyplot's current figure).
fig = plt.figure()
