In [None]:
# Standard library
import os

# Third-party
from astropy.io import fits
import astropy.coordinates as coord
import astropy.units as u
from astropy.table import Table
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
from scipy.special import logsumexp
from scipy.optimize import minimize

from pyia import GaiaData

In [None]:
# Galactocentric frame: Sun at 8.1 kpc from the Galactic centre, placed in the midplane (z_sun = 0)
gc_frame = coord.Galactocentric(galcen_distance=8.1*u.kpc,
                                z_sun=0*u.pc)

In [None]:
# Circumradius [pc] of each hexagonal spatial cell (passed to get_hexagons)
hex_h = 200

# Load the catalogue; later cells read positions and velocities from its sky coordinates
filename = '../data/rv-good-plx.fits'
g = GaiaData(filename)
c = g.skycoord

# Same stars expressed in the Galactocentric frame defined above
galcen = c.transform_to(gc_frame)

In [None]:
# Heliocentric Galactic coordinates of every star. Switching to the cartesian
# representation lets later cells read positions as gal.u/v/w and velocities as gal.U/V/W.
gal = c.galactic
gal.set_representation_cls('cartesian')

In [None]:
# Galactocentric cylindrical view of the same stars (read by the commented-out
# alternative velocity definition below)
cyl = gal.transform_to(gc_frame)
cyl.set_representation_cls('cylindrical')

# Heliocentric Galactic positions [pc] and velocities [km/s]: one row per star,
# columns ordered (u, v, w) and (U, V, W) respectively
xyz = np.column_stack([getattr(gal, comp).to_value(u.pc)
                       for comp in ('u', 'v', 'w')])
UVW = np.column_stack([getattr(gal, comp).to_value(u.km/u.s)
                       for comp in ('U', 'V', 'W')])

# Alternative velocity definition from the Galactocentric cylindrical components:
# UVW = np.vstack((cyl.d_rho.to(u.km/u.s).value, 
#                  - ((cyl.rho * cyl.d_phi).to(u.km/u.s, u.dimensionless_angles()).value + 220.), 
#                  cyl.d_z.to(u.km/u.s).value)).T

# Keep stars with heliocentric speed below 120 km/s (named disk_vmask: presumably
# meant to select disk-like kinematics) and apply the same cut to positions
disk_vmask = np.linalg.norm(UVW, axis=1) < 120
xyz, UVW = xyz[disk_vmask], UVW[disk_vmask]

---

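## Hexagons

Tile the solar neighbourhood with a central hexagon (circumradius `hex_h` pc) and the ring of six hexagons surrounding it. Later cells compare the velocity distributions of stars in each of these cells.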

In [None]:
def get_hexagons(h):
    """Vertices of a central hexagon plus the first ring of six neighbours.

    Every hexagon is a ``matplotlib.patches.RegularPolygon`` with 6 vertices,
    circumradius ``h`` and ``orientation=pi/2``. Ring-1 centres lie at distance
    ``2 * apothem = sqrt(3) * h`` from the origin, at angles 90, 150, ..., 30 deg.

    Parameters
    ----------
    h : float
        Circumradius (centre-to-vertex distance) of every hexagon, in the
        same units as the returned vertices.

    Returns
    -------
    dict
        Keyed by ring number: ``0`` -> list with the central hexagon's vertex
        array, ``1`` -> list of the six surrounding hexagons' vertex arrays
        (each as returned by ``Patch.get_verts()``).
    """
    def _hex_verts(center):
        patch = mpl.patches.RegularPolygon(center, numVertices=6,
                                           radius=h, orientation=np.pi/2)
        return patch.get_verts()

    # apothem (centre-to-edge distance) is sqrt(3)/2 * h; neighbours sit two apothems away
    spacing = 2 * (np.sqrt(3)/2 * h)

    central = [_hex_verts((0., 0.))]

    ring1 = []
    for ang in np.arange(0, 360, 60)*u.deg:
        phi = ang + 90*u.deg
        ring1.append(_hex_verts(spacing * np.array([np.cos(phi), np.sin(phi)])))

    return {0: central, 1: ring1}

In [None]:
hexs = get_hexagons(h=hex_h)  # ring number -> list of hexagon vertex arrays

In [None]:
# Hexagonal cells drawn over the heliocentric (x, y) positions of all stars
fig, ax = plt.subplots(1, 1, figsize=(6, 6))

for ring in hexs.keys():
    for verts in hexs[ring]:
        ax.add_patch(mpl.patches.Polygon(verts, facecolor='none',
                                         edgecolor='tab:green',
                                         zorder=100))

# Convert explicitly (as done for `xyz`) so the plotted numbers match the
# "[pc]" axis labels and the pc-sized hexagons, whatever unit `gal` stores
ax.plot(gal.u.to_value(u.pc), gal.v.to_value(u.pc),
        marker=',', alpha=0.1, color='k',
        linestyle='none', zorder=10)

lim = 1000
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)

ax.set_xlabel('$x$ [pc]')
ax.set_ylabel('$y$ [pc]')

---

In [None]:
# NOTE(review): disabled earlier version of the per-hexagon velocity-histogram
# loop further down in this notebook (which also handles the "adjusted" velocities);
# kept for reference only. Its axis labels ("$-v_y-220$", "$v_x$", "$v_z$") match
# the commented-out Galactocentric UVW definition, not the heliocentric U, V, W now in use.
# vmax = 1400
# for k in hexs.keys():
#     for b, pa in enumerate(hexs[k]):
#         hex_mask = mpl.patches.Path(pa).contains_points(xyz[:, :2])
#         print(hex_mask.sum())

#         lim = 150
#         bins = np.linspace(-lim, lim, 101)

#         fig, axes = plt.subplots(2, 2, figsize=(8, 7.2))
        
#         for a, (i, j) in zip([0, 2, 3], [(0, 1), (0, 2), (1, 2)]):
#             ax = axes.flat[a]
#             H, xe, ye = np.histogram2d(UVW[hex_mask,i], UVW[hex_mask,j], bins=bins)
#             ax.pcolormesh(xe, ye, H.T, 
#                           norm=mpl.colors.LogNorm(), 
#                           cmap='magma', vmin=1, vmax=vmax)
#             ax.set_xlim(-lim, lim)
#             ax.set_ylim(-lim, lim)
        
#         axes[0, 0].set_ylabel('$-v_y-220$')
#         axes[1, 0].set_ylabel('$v_z$')
#         axes[1, 0].set_xlabel('$v_x$')
#         axes[1, 1].set_xlabel('$-v_y-220$')
#         axes[0, 0].xaxis.set_ticklabels([])
#         axes[1, 1].yaxis.set_ticklabels([])
        
#         # axes[0,1].set_visible(False)
#         ax = axes[0,1]
#         for k_ in hexs.keys():
#             for pa_ in hexs[k_]:
#                 pa_ = mpl.patches.Polygon(pa_, facecolor='none', edgecolor='#333333')
#                 ax.add_patch(pa_)
                
#         ax.add_patch(mpl.patches.Polygon(pa, facecolor='#333333', edgecolor='#333333'))
#         ax.set_xlim(-750, 750)
#         ax.set_ylim(-750, 750)
#         ax.set_xlabel('$x$ [pc]')
#         ax.set_ylabel('$y$ [pc]')

#         fig.tight_layout()
#         fig.savefig('../plots/uvw/big-hex-{2}-{0}-{1:02d}.png'.format(k, b, 'all'), dpi=250)
#         plt.close(fig)

In [None]:
# Heliocentric cylindrical coordinates of the stars that passed the velocity cut
_cyl = gal.represent_as('cylindrical')[disk_vmask]
_rho, _absz = _cyl.rho, np.abs(_cyl.z)

# Nested selection volumes: cylinder radius and |z| half-height
mask_r100 = (_rho < 100*u.pc) & (_absz < 250*u.pc)
mask_r300 = (_rho < 300*u.pc) & (_absz < 500*u.pc)
mask_r500 = (_rho < 500*u.pc) & (_absz < 500*u.pc)

# Star counts in the smallest and largest volumes
mask_r100.sum(), mask_r500.sum()

In [None]:
def kde_obj(V_test, V_control, bw):
    """Summed Gaussian-KDE log-density of test points under a control sample.

    Each test point is scored by the log of an unnormalised Gaussian kernel
    density built from the control points; the scores are summed:

        S = sum_n log sum_m exp(-0.5 * sum_k ((V_test[n,k] - V_control[m,k]) / bw_k)**2)

    Parameters
    ----------
    V_test : array_like, shape (n, d)
        Points whose log-density is evaluated and differentiated.
    V_control : array_like, shape (m, d)
        Points defining the kernel density estimate.
    bw : float or array_like, shape (d,)
        Kernel bandwidth, either shared or one value per dimension.

    Returns
    -------
    scalar : float
        The summed log-density ``S`` (no kernel normalisation constant).
    grad : ndarray, shape (n, d)
        ``dS / dV_test``.
    """
    V_test = np.asarray(V_test)
    V_control = np.asarray(V_control)
    bw = np.asarray(bw, dtype=float)  # allows a per-dimension list of bandwidths

    delta = V_test[None] - V_control[:, None]  # (m, n, d)

    # objective function
    arg = np.sum(-0.5 * (delta / bw)**2, axis=-1)  # (m, n)
    lse = logsumexp(arg, axis=0)  # (n,) per-test-point log-density, numerically stable
    scalar = lse.sum()

    # gradient: softmax-weighted mean over control points of d(arg)/d(V_test).
    # Weights come from the logsumexp above, so the (m, n) exponential is
    # evaluated once and each column of w sums to 1.
    w = np.exp(arg - lse[None])  # (m, n)
    grad = np.sum(w[..., None] * (-delta / bw**2), axis=0)  # (n, d)

    return scalar, grad

In [None]:
# Thin both samples to keep the (m, n, 3) arrays built inside kde_obj manageable
step_control, step_test = 24, 128
control_v = UVW[mask_r100][::step_control]
test_v = UVW[mask_r300][::step_test]
test_x = xyz[mask_r300][::step_test]  # positions matching test_v row-for-row

control_v.shape, test_v.shape

In [None]:
_s, _g = kde_obj(V_test=test_v, V_control=control_v, bw=5.0)  # objective and its gradient w.r.t. test_v

In [None]:
# Disabled forward-difference check of one element (row 612, column 1) of
# kde_obj's gradient. The call below originally referenced an undefined
# `anchor_v`; `control_v` is the control sample built earlier.
# test_v2 = test_v.copy()
# fudge = .7
# test_v2[612, 1] += fudge

# _s2, _g2 = kde_obj(test_v2, control_v, bw=5.)

In [None]:
# Forward-difference estimate of d(_s)/d(test_v[612, 1]), to compare with _g[612, 1]
# (_s2 - _s) / fudge # check numerical derivative

In [None]:
# Analytic gradient element corresponding to the forward-difference estimate
# _g[612, 1]

In [None]:
def f_and_grad(p, V_test, dX_test, V_control, bw):
    """Negative KDE objective and gradient for a linear velocity correction.

    The test velocities are shifted by a linear function of the test
    positions, ``V_test + dX_test @ A`` with ``A = p.reshape(d_x, d_v)``
    (row-major), and scored with ``kde_obj`` against the control velocities.
    Signs are flipped so the pair can be handed to a minimiser that accepts
    ``jac=True``.

    Parameters
    ----------
    p : array_like, shape (d_x * d_v,)
        Flattened correction matrix; 9 values for 3-D positions/velocities.
    V_test : array_like, shape (n, d_v)
        Velocities to be corrected.
    dX_test : array_like, shape (n, d_x)
        Positions multiplying the correction matrix.
    V_control : array_like, shape (m, d_v)
        Control velocities defining the KDE.
    bw : float or array_like
        Kernel bandwidth forwarded to ``kde_obj``.

    Returns
    -------
    f : float
        Minus the ``kde_obj`` scalar.
    grad : ndarray, shape (d_x * d_v,)
        Minus its derivative w.r.t. ``p``, in the same order as ``p``.
    """
    V_test = np.asarray(V_test)
    dX_test = np.asarray(dX_test)
    aij = np.asarray(p).reshape(dX_test.shape[1], V_test.shape[1])
    s, g = kde_obj(V_test + dX_test.dot(aij),
                   V_control, bw)
    # dS/dA[j, k] = sum_n dX[n, j] * g[n, k]: one matrix product instead of an
    # (n, d_x, d_v) temporary
    grad = dX_test.T.dot(g).ravel()
    return -s, -grad

In [None]:
# Small random starting matrix for the gradient check; seeded so re-runs give
# the same numbers
rng = np.random.default_rng(42)
p = 1e-3 * rng.random(size=9)
_s, _g = f_and_grad(p, test_v, test_x, control_v, bw=5.)

In [None]:
# Central-difference check of every component of f_and_grad's analytic
# gradient (previously only p[4] was checked, one-sided)
eps = 1e-5
num_grad = np.empty_like(p)
for idx in range(p.size):
    step = np.zeros_like(p)
    step[idx] = eps
    f_plus = f_and_grad(p + step, test_v, test_x, control_v, bw=5.)[0]
    f_minus = f_and_grad(p - step, test_v, test_x, control_v, bw=5.)[0]
    num_grad[idx] = (f_plus - f_minus) / (2 * eps)

# Columns: numerical, analytic; rows should agree closely
np.column_stack((num_grad, _g))

In [None]:
# Fit the 9 coefficients of the linear velocity correction by minimising the
# negative KDE objective of the test stars under the control sample (bw = 3 km/s)
res = minimize(
    f_and_grad,
    x0=np.zeros(9),
    args=(test_v, test_x, control_v, 3.),
    method='L-BFGS-B',
    jac=True,  # f_and_grad returns (objective, gradient)
)

In [None]:
# Optimiser summary: check `success` and `message` before using `res.x`
res

In [None]:
# Best-fit coefficients; reshape(3, 3) gives the matrix A applied as v + x @ A
res.x

In [None]:
# Corrected test velocities: v + x @ A
pred_v = test_v + test_x @ res.x.reshape(3, 3)
pred_v.shape

In [None]:
# U-V plane: control sample, a random subset of the r500 volume, and the same
# subset after applying the fitted correction v + x @ A
A_fit = res.x.reshape(3, 3)

fig, axes = plt.subplots(1, 3, figsize=(15, 5),
                         sharex=True, sharey=True)

axes[0].plot(control_v[:, 0], control_v[:, 1],
             marker='.', alpha=0.1, ls='none')

# Seeded for reproducibility, and capped so the draw without replacement does
# not raise when the r500 volume holds fewer than 10000 stars
n_r500 = int(mask_r500.sum())
rand_idx = np.random.default_rng(42).choice(n_r500, size=min(10000, n_r500),
                                            replace=False)
_v = UVW[mask_r500][rand_idx]
_x = xyz[mask_r500][rand_idx]
_v_adj = _v + _x.dot(A_fit)

axes[1].plot(_v[:, 0], _v[:, 1],
             marker='.', alpha=0.1, ls='none')
axes[2].plot(_v_adj[:, 0], _v_adj[:, 1],
             marker='.', alpha=0.1, ls='none')

for ax, title in zip(axes, ['control (r100)', 'data (r500)', 'adjusted (r500)']):
    ax.set_title(title)
    ax.set_xlabel('$U$ [km/s]')
axes[0].set_ylabel('$V$ [km/s]')

axes[0].set_xlim(-150, 150)
axes[0].set_ylim(-150, 150)

In [None]:
# Size of the applied correction for each star in the random subset, shown over
# the corrected U-V distribution. The correction is x @ A (positions times the
# fitted matrix), i.e. _v_adj - _v; the previous _v @ A used velocities instead.
dv = _x.dot(res.x.reshape(3, 3))
dvmag = np.linalg.norm(dv, axis=-1)

fig, ax = plt.subplots(1, 1, figsize=(6, 6))
sc = ax.scatter(_v_adj[:, 0], _v_adj[:, 1],
                c=dvmag,
                cmap='viridis', vmin=0, vmax=2,
                marker='.', linewidth=0, alpha=0.2)
fig.colorbar(sc, ax=ax, label=r'$|\Delta v|$ [km/s]')
ax.set_xlabel('$U$ [km/s]')
ax.set_ylabel('$V$ [km/s]')

In [None]:
# Per-hexagon heliocentric velocity distributions (U-V, U-W, V-W) for the raw
# velocities ("data") and after the fitted correction v + x @ A ("adjusted").
# One PNG per (version, ring, hexagon) is written to ../plots/uvw/.
vmax = 1400
lim = 150
bins = np.linspace(-lim, lim, 101)
A_fit = res.x.reshape(3, 3)
# UVW columns are the heliocentric Galactic U, V, W taken from gal.U/V/W
vel_labels = ['$U$ [km/s]', '$V$ [km/s]', '$W$ [km/s]']

os.makedirs('../plots/uvw', exist_ok=True)

for k in hexs.keys():
    for b, pa in enumerate(hexs[k]):
        # stars whose (x, y) position lies inside this hexagon (same for both versions)
        hex_mask = mpl.patches.Path(pa).contains_points(xyz[:, :2])
        print(hex_mask.sum())
        this_x = xyz[hex_mask]
        this_v = UVW[hex_mask]

        for name in ['data', 'adjusted']:
            if name == 'adjusted':
                the_v = this_v + this_x.dot(A_fit)
            else:
                the_v = this_v

            fig, axes = plt.subplots(2, 2, figsize=(8, 7.2))

            for a, (i, j) in zip([0, 2, 3], [(0, 1), (0, 2), (1, 2)]):
                ax = axes.flat[a]
                H, xe, ye = np.histogram2d(the_v[:, i], the_v[:, j], bins=bins)
                # vmin/vmax belong on the norm: recent matplotlib rejects
                # passing them alongside norm=
                ax.pcolormesh(xe, ye, H.T,
                              norm=mpl.colors.LogNorm(vmin=1, vmax=vmax),
                              cmap='magma')
                ax.set_xlim(-lim, lim)
                ax.set_ylim(-lim, lim)

            axes[0, 0].set_ylabel(vel_labels[1])
            axes[1, 0].set_ylabel(vel_labels[2])
            axes[1, 0].set_xlabel(vel_labels[0])
            axes[1, 1].set_xlabel(vel_labels[1])
            axes[0, 0].xaxis.set_ticklabels([])
            axes[1, 1].yaxis.set_ticklabels([])

            # top-right panel: map of all hexagons with the current one filled
            ax = axes[0, 1]
            for k_ in hexs.keys():
                for pa_ in hexs[k_]:
                    ax.add_patch(mpl.patches.Polygon(pa_, facecolor='none', edgecolor='#333333'))
            ax.add_patch(mpl.patches.Polygon(pa, facecolor='#333333', edgecolor='#333333'))
            ax.set_xlim(-750, 750)
            ax.set_ylim(-750, 750)
            ax.set_xlabel('$x$ [pc]')
            ax.set_ylabel('$y$ [pc]')

            fig.tight_layout()
            fig.savefig('../plots/uvw/bighex-{2}-{0}-{1:02d}.png'.format(k, b, name), dpi=250)
            plt.close(fig)

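Assemble the ring-1 frames written by the loop above into animated GIFs with ImageMagick. The PNGs are saved in `../plots/uvw/`, so run the commands from there:

```bash
cd ../plots/uvw
convert -delay 30 -loop 1 bighex-data-1-*.png bighex-data-ring.gif
convert -delay 30 -loop 1 bighex-adjusted-1-*.png bighex-adjusted-ring.gif
```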