In [None]:
import mpl_toolkits
import matplotlib
import matplotlib.pyplot as plt
import astropy.units as u
from astropy.io.ascii import read
from pathlib import Path
import numpy as np

In [None]:
# The seaborn style sheets shipped with matplotlib were renamed to
# 'seaborn-v0_8-<name>' in matplotlib 3.6 (the old names were deprecated and
# later removed). Prefer the versioned name when this matplotlib provides it
# and fall back to the legacy name on older releases.
for _style in ('paper', 'colorblind'):
    _versioned = f'seaborn-v0_8-{_style}'
    plt.style.use(_versioned if _versioned in plt.style.available else f'seaborn-{_style}')

In [None]:
def groupby_alt(pointings):
    """
    Group the rows of a pointings table by their altitude.

    Parameters
    ----------
    pointings: object exposing ``to_pandas()``
        The resulting DataFrame must contain an ``'alt'`` column
        (presumably an `astropy.table.Table` -- verify against caller).

    Returns
    -------
    dict-like
        Maps every distinct ``'alt'`` value to the index labels of the rows
        holding that value.
    """
    frame = pointings.to_pandas()
    by_altitude = frame.groupby('alt')
    return by_altitude.groups

In [None]:
def _altaz_to_cartesian(az, alt, r=1.):
    """Convert azimuth/altitude angles (rad) to (X, Y, Z) on a sphere of radius `r`."""
    return (
        r * np.cos(az) * np.cos(alt),
        r * np.sin(az) * np.cos(alt),
        r * np.sin(alt),
    )


def plot_pointings(pointings, ax=None, projection='polar', add_grid3d=False, **kwargs):
    """
    Produce a scatter plot of the pointings.
    Adapted from lstmcpipe==0.9.0:
    https://github.com/cta-observatory/lstmcpipe/blob/v0.9.0/lstmcpipe/plots/pointings.py

    Parameters
    ----------
    pointings: 2D array of `astropy.quantities` or numpy array in rad
        Column 0 is the azimuth, column 1 the altitude. Quantities carrying an
        angular unit are converted to radians; unit-less input is taken to be
        in radians already.
    ax : `matplotlib.pyplot.Axis`
        Axes to draw on. When None, a new subplot with the requested
        projection is added to the current figure.
    projection: str or None
        '3d' | 'aitoff' | 'hammer' | 'lambert' | 'mollweide' | 'polar' | 'rectilinear'
        Pass None together with `ax` to infer the projection from the axes type.
    add_grid3d: bool
        add a 3D grid in case of projection='3d'
    kwargs: dict
        kwargs for `matplotlib.pyplot.scatter`

    Returns
    -------
    ax: `matplotlib.pyplot.axis`

    Raises
    ------
    ValueError
        If both `ax` and `projection` are given and `ax` is not an instance of
        the axes class registered for `projection`.
    """
    if ax is not None and projection:
        if not isinstance(ax, matplotlib.projections.get_projection_class(projection)):
            raise ValueError(f"ax of type {type(ax)} and projection {projection} are exclusive")

    # Normalise once to plain floats in radians so that every branch below
    # works for both Quantity input and unit-less numpy arrays.
    pointings_rad = u.Quantity(pointings, u.rad).to_value(u.rad)
    az = pointings_rad[:, 0]
    alt = pointings_rad[:, 1]

    if ax is None:
        fig = plt.gcf()
        ax = fig.add_subplot(111, projection=projection)
    # NOTE(review): attribute access assumes the mplot3d submodule has already
    # been imported (e.g. by matplotlib when registering '3d') -- confirm.
    elif isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D):
        projection = '3d'
    elif isinstance(ax, matplotlib.projections.polar.PolarAxes):
        projection = 'polar'

    if projection == '3d':
        r = 1.
        if add_grid3d:
            # Translucent reference surface over the upper hemisphere.
            grid_az, grid_alt = np.mgrid[0:2.01*np.pi:(1/10)* np.pi/2., 0:np.pi/2.:(1/10)* np.pi/4.]
            ax.plot_surface(*_altaz_to_cartesian(grid_az, grid_alt, r),
                            cmap=plt.cm.YlGnBu_r, alpha=.1)

        # Points sit on a slightly larger radius than the reference surface.
        r *= 1.01
        ax.scatter(*_altaz_to_cartesian(az, alt, r), **kwargs)

        box_ratio = 1.03
        ax.set_xlim3d([-r*box_ratio, r*box_ratio])
        ax.set_ylim3d([-r*box_ratio, r*box_ratio])
        ax.set_zlim3d([-r*box_ratio, r*box_ratio])

    elif projection == 'polar':
        # The radial coordinate is pi/2 - altitude, so altitude 90 deg is at the centre.
        ax.scatter(az, np.pi/2. - alt, **kwargs)
        ax.set_xlabel('Azimuth')
        rticks_deg = [10, 30, 50, 70, 90]
        ax.set_rticks(np.deg2rad(rticks_deg), labels=[f'{r}°' for r in rticks_deg])
        ax.set_rmax(np.pi/2.)
        ax.set_rlabel_position(20)

    else:
        ax.scatter(az, alt, **kwargs)
        ax.set_xlabel('Azimuth')
        ax.set_ylabel('Altitude')

    # Only draw a legend when there is something to show; calling legend()
    # without any labelled artist makes matplotlib emit a warning.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    ax.grid(True)
    ax.set_axisbelow(True)

    return ax

In [None]:
# Read the test pointings; the format is stated explicitly (as for the
# training file) so astropy does not have to guess it from the content.
pointings_test = read('pointings_test.ecsv', format='ecsv')

In [None]:
# Show the test pointings table loaded above as this cell's output.
pointings_test

In [None]:
# Load the training pointings and repack them as one Quantity in radians whose
# columns are (az, alt) -- the layout `plot_pointings` indexes into.
train_table = read('pointings_train.ecsv', format='ecsv')
train_az = train_table['az'].to(u.rad)
train_alt = train_table['alt'].to(u.rad)
pointings_train = np.transpose([train_az, train_alt]) * u.rad
pointings_train

In [None]:
plt.figure(figsize=(5, 5))

# Training nodes: drawn by the helper, which adds a polar axes to the current figure.
ax = plot_pointings(pointings_train, label='Training nodes', color='black', s=8)

# Testing nodes: one star series (and one legend entry) per distinct altitude.
grp = groupby_alt(pointings_test)
for row_labels in grp.values():
    rows = list(row_labels)
    # Zenith distance of the group, taken from its first row.
    zenith_distance = 90*u.deg - pointings_test[rows[0]]['alt']
    ax.scatter(
        pointings_test[rows]['az'].to_value(u.rad),
        np.pi/2. - pointings_test[rows]['alt'].to_value(u.rad),
        marker='*',
        label=f"Testing nodes zd={zenith_distance:.2f}",
        s=100,
    )

# Replace the radial ticks set by the helper with a 10-degree grid up to 60 degrees.
zenith_ticks_deg = [10, 20, 30, 40, 50, 60]
ax.set_rticks(np.deg2rad(zenith_ticks_deg), [f'{r:d}°' for r in zenith_ticks_deg])

# Caption for the radial axis, placed next to the tick labels and rotated with them.
label_angle_deg = ax.get_rlabel_position()
ax.text(
    np.radians(label_angle_deg + 10),
    2*ax.get_rmax()/3.,
    'Zenith',
    rotation=label_angle_deg,
    ha='center',
    va='center',
)

# Dotted lines at azimuth 175.158 deg and the opposite direction, labelled once
# in the legend as the magnetic North-South axis.
magnetic_az = np.deg2rad(175.158)
for az_offset, legend_entry in ((0., {'label': 'magnetic North-South'}), (np.pi, {})):
    ax.vlines(
        magnetic_az + az_offset, 0, 0.99*ax.get_rmax(),
        ls='dotted', color='grey', lw=1, zorder=0,
        **legend_entry,
    )

ax.legend(fontsize=8, loc='lower right', bbox_to_anchor=(1.1, 0.16))

plt.tight_layout()
plt.savefig(Path('.') / 'pointings_per_alt.png', dpi=250)
plt.show()