# Close approach

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Setup" data-toc-modified-id="Setup-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Setup</a></span></li><li><span><a href="#Static-plot" data-toc-modified-id="Static-plot-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Static plot</a></span></li><li><span><a href="#Interactive-plot" data-toc-modified-id="Interactive-plot-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Interactive plot</a></span><ul class="toc-item"><li><span><a href="#Define-interaction-methods" data-toc-modified-id="Define-interaction-methods-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Define interaction methods</a></span></li><li><span><a href="#Run-interactive-plot" data-toc-modified-id="Run-interactive-plot-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Run interactive plot</a></span></li></ul></li><li><span><a href="#Looking-at-the-midplane" data-toc-modified-id="Looking-at-the-midplane-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Looking at the midplane</a></span></li></ul></div>

## Setup

In [1]:
import numpy as np
from numpy.linalg import norm
import pandas as pd

# import plotting modules
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import matplotlib.widgets as mw  # get access to the widgets

%matplotlib qt

from galaxy.db import DB
from galaxy.galaxies import Galaxies
from galaxy.galaxy import Galaxy
from galaxy.timecourse import TimeCourse
from galaxy.plots import Plots
from galaxy.approaches import Approaches

from galaxy.utilities import rotation_matrix_to_vector, z_rotation_matrix

In [2]:
import mpl_scatter_density

# Make the norm object to define the image stretch
from astropy.visualization import LogStretch
from astropy.visualization.mpl_normalize import ImageNormalize
lognorm = ImageNormalize(vmin=0., vmax=1000, stretch=LogStretch())

In [3]:
#adjust tick label font size
label_size = 14
matplotlib.rcParams['xtick.labelsize'] = label_size 
matplotlib.rcParams['ytick.labelsize'] = label_size

## Static plot

In [4]:
def plot_density_views(disks, title, xlim=150, ylim=150, pngout=False, fname=None):
    fig = plt.figure(figsize=(10,10))
    fontsize = 16

    # top left
    ax0 = fig.add_subplot(1, 1, 1, projection='scatter_density')
    ax0.scatter_density(disks[0], disks[1], norm=lognorm)

    ax0.set_xlim(-xlim, xlim)
    ax0.set_ylim(-ylim, ylim)

    # Add axis labels (standard pyplot)
    ax0.set_xlabel('x (kpc)', fontsize=fontsize)
    ax0.set_ylabel('y (kpc)', fontsize=fontsize)
    
    # Save file
    if pngout:
        plt.savefig(fname, dpi='figure')
        plt.close();  

In [5]:
snap = 301
app = Approaches(snap=snap, usesql=True)
t = app.time.value / 1000
disks = app.xyz()
title = f"\n\nSnap: {snap:03}\nElapsed time: {t:5.3f} Gyr"

In [6]:
plot_density_views(disks, title, pngout=False, fname=f"approach_{snap:03}.png")
print(snap, end=' ')

301 

  vmin = self._density_vmin(array)
  vmax = self._density_vmax(array)


## Interactive plot

Get the raw data

In [7]:
snap = 300
app = Approaches(snap=snap, usesql=True)
t = app.time.value / 1000
t

4.28571

Center coordinates visually for plotting

In [8]:
# get two (3,N) arrays with just position/velocity coordinates
xyz = app.xyz()
vxyz = app.vxyz()

v = norm(vxyz, axis=0)

# center the collection visually (not CoM)
centroid = np.mean(xyz, axis=1)
xyz -= centroid[:,np.newaxis]
xyz.shape, v.shape

((3, 1166500), (1166500,))

### Define interaction methods

In [9]:
index = None

def callbackRectangle( click, release ): # the events are click and release
    """
    """
    
    global index
    
    # create a rectangle
    width = np.abs(release.xdata - click.xdata)
    height = np.abs(release.ydata - click.ydata)
    
    rect = plt.Rectangle( (click.xdata, click.ydata), width, height,
                            fill=False, color='yellow', lw=1)
    
    # clear old rectangles, add new one
    [p.remove() for p in reversed(ax0.patches)]
    ax0.add_patch(rect)
    
    # extrema will be useful for setting axes
    xmin = min([click.xdata, release.xdata])
    xmax = max([click.xdata, release.xdata])
    ymin = min([click.ydata, release.ydata])
    ymax = max([click.ydata, release.ydata])
    
    index = np.where( (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax) )
    
    # fill in mid and right panels
    add_velocity_plot(ax1, xmin, xmax, ymin, ymax)
    add_origin_plot(ax2, xmin, xmax, ymin, ymax)
    plt.tight_layout()
    
#     ref = ax[1].scatter(x[index], y[index], s=1)

    # Save the file

In [10]:
def onKeyPressed(event):
    
    # eventually want to reset by removing selection
    pass

#     if event.key in ['R', 'r']:
#         ax.set_xlim(-30,30)
#         ax.set_ylim(-30,30)

In [11]:
def add_velocity_plot(ax, xmin, xmax, ymin, ymax):

    sel_inx = index[0]
    sel_xyz = xyz[:,sel_inx]
    sel_vxyz = vxyz[:,sel_inx]
#     print('x: ', len(selected['x']), selected['x'])

    if index is not None:
#         ax1.quiver(x[index], y[index], vx[index], vy[index], color='gray')
        ax1.quiver(sel_xyz[0], sel_xyz[1], sel_vxyz[0], sel_vxyz[1], color='gray')

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # Add axis labels (standard pyplot)
    ax.set_xlabel('x (kpc)', fontsize=fontsize)
    ax.set_ylabel('y (kpc)', fontsize=fontsize)

In [12]:
def add_origin_plot(ax, xmin, xmax, ymin, ymax):

    if index is not None:
        sel_inx = index[0]
        sel_x, sel_y, sel_z = xyz[:,sel_inx]
        sel_data = app.data[sel_inx]
        sel_galname = sel_data['galname']
        sel_type = sel_data['type']
        
        markersize = 5
        group = np.where((sel_galname=='MW ') & (sel_type==2))
        ax.scatter(sel_x[group], sel_y[group], color='k', s=markersize, label='MW disk')
        
        group = np.where((sel_galname=='M31') & (sel_type==2))
        ax.scatter(sel_x[group], sel_y[group], color='g', s=markersize, label='M31 disk')
        
        group = np.where((sel_galname=='MW ') & (sel_type==3))
        ax.scatter(sel_x[group], sel_y[group], color='r', s=markersize, marker='^', label='MW bulge')
        
        group = np.where((sel_galname=='M31') & (sel_type==3))
        ax.scatter(sel_x[group], sel_y[group], color='b', s=markersize, marker='^', label='M31 bulge')

    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)

    # Add axis labels (standard pyplot)
    ax.set_xlabel('x (kpc)', fontsize=fontsize)
    ax.set_ylabel('y (kpc)', fontsize=fontsize)
    
    if ax.get_legend() is not None:
        ax.get_legend().remove()
    ax.legend(shadow=True)

### Run interactive plot

In [13]:
x,y,z = xyz
vx,vy,vz = vxyz

fig = plt.figure(figsize=(18,6))
fontsize = 16
titlesize = 20

xlim = ylim = 150

# top left
ax0 = fig.add_subplot(1, 3, 1, projection='scatter_density')
ax0.scatter_density(x, y, norm=lognorm)

ax0.set_xlim(-xlim, xlim)
ax0.set_ylim(-ylim, ylim)

# Add axis labels (standard pyplot)
ax0.set_xlabel('x (kpc)', fontsize=fontsize)
ax0.set_ylabel('y (kpc)', fontsize=fontsize)

ax0.set_title(f"Density at t={t:5.2f} Gyr", fontsize=titlesize)

ax1 = fig.add_subplot(1, 3, 2)
ax1.set_title('Velocities of selected stars', fontsize=titlesize)

ax2 = fig.add_subplot(1, 3, 3)
ax2.set_title('Origins of selected stars', fontsize=titlesize)



rs = mw.RectangleSelector( ax0,                        # the axes to attach to
                           callbackRectangle,         # the callback function
                           drawtype='box',            # draw a box when selecting a region
                           button=[1, 3],             # allow us to use left or right mouse button
                                                      #button 1 is left mouse button
                           minspanx=5, minspany=5,    # don't accept a box of fewer than 5 pixels
                           spancoords='pixels' )      # units for above

# to detect the 'R' key press to reset the image
plt.connect("key_press_event", onKeyPressed);

In [14]:
sel_inx = index[0]

selected = app.data[sel_inx]
print(len(sel_inx))
selected['x']

TypeError: 'NoneType' object is not subscriptable

## Looking at the midplane

In [15]:
tc = TimeCourse()

Get the MW, M31 CoM positions:

In [16]:
snap = 300
MW_com, _ = tc.get_one_com('MW', snap)
M31_com, _ = tc.get_one_com('M31', snap)
MW_com

array([-53.84, 262.48, -15.17], dtype=float32)

Rotate and translate the coordinate system so the CoMs are along the x-axis.

First rotate:

In [17]:
MW_M31_vec = M31_com - MW_com
MW_M31_vec /= norm(MW_M31_vec)
x_hat = np.array([1.,0.,0.])
R = rotation_matrix_to_vector(MW_M31_vec, x_hat)
R

array([[ 0.04675856, -0.9672069 , -0.24964871],
       [ 0.9672069 ,  0.10629901, -0.23067587],
       [ 0.24964871, -0.23067587,  0.94045955]])

In [18]:
MW_com_rot = R @ MW_com
M31_com_rot = R @ M31_com
MW_com_rot, M31_com_rot

(array([-252.60278664,  -20.67370071,  -88.25566336]),
 array([-128.98909469,  -20.67370007,  -88.25566438]))

Next translate:

In [None]:
offset = np.array([(M31_com_rot[0]  + MW_com_rot[0])/2, MW_com_rot[1], MW_com_rot[2]])
offset

In [None]:
xyz = app.xyz()

xyz_rot = R @ xyz
xyz_rot -= offset[:, np.newaxis]

Plot in the new coordinate system:

In [19]:
def midplane_view_coords(raw_xyz, snap):
    """
    Rotate the raw coordinates to put MW, M31 along the x-axis, 
    equally spaced about the origin
    """
    
    # get CoM positions
    tc = TimeCourse()
    MW_com, _ = tc.get_one_com('MW', snap)
    M31_com, _ = tc.get_one_com('M31', snap)

    # calculate rotation matrix
    MW_M31_vec = M31_com - MW_com
    MW_M31_vec /= norm(MW_M31_vec)
    x_hat = np.array([1.,0.,0.])
    R = rotation_matrix_to_vector(MW_M31_vec, x_hat)

    # calculate offset from desired position
    MW_com_rot = R @ MW_com
    M31_com_rot = R @ M31_com
    offset = np.array([(M31_com_rot[0]  + MW_com_rot[0])/2, 
                       MW_com_rot[1], MW_com_rot[2]])
    
    # transform coordinates and return
    xyz_rot = R @ raw_xyz
    xyz_rot -= offset[:, np.newaxis]
    return xyz_rot

In [22]:
snap = 300
app = Approaches(snap=snap, usesql=True)
t = app.time.value / 1000
t

4.28571

In [23]:
x, y, z = midplane_view_coords(app.xyz(), snap)

fig = plt.figure(figsize=(10,6))
fontsize = 16
titlesize = 20

xlim = 120
ylim = 70

# top left
ax0 = fig.add_subplot(1, 1, 1, projection='scatter_density')
ax0.plot([-xlim, xlim], [0,0], color='gray', lw=1)
ax0.plot([0,0], [-ylim, ylim], color='gray', lw=1)
ax0.scatter_density(x, z, norm=lognorm)

ax0.set_xlim(-xlim, xlim)
ax0.set_ylim(-ylim, ylim)

# Add axis labels (standard pyplot)
ax0.set_xlabel('x (kpc)', fontsize=fontsize)
ax0.set_ylabel('z (kpc)', fontsize=fontsize)

ax0.set_title(f"Density at t={t:5.2f} Gyr", fontsize=titlesize)

fname = f'density_rot_{snap}.pdf'
plt.savefig(fname, dpi='figure');

  vmin = self._density_vmin(array)
  vmax = self._density_vmax(array)


Select stars that are roughly equidistant from both CoMs (i.e. near x=0):

In [24]:
bridge_index = np.where( (x > -20) & (x < 30) )
bridge_data = app.data[bridge_index]
len(bridge_data), bridge_data

(2763,
 array([('M31', 2, 2.000e-05, -28.7791, 186.126, -80.1723,  117.818 , -117.963 , -22.2706),
        ('M31', 2, 2.000e-05, -39.6964, 198.329, -75.5905,   89.4208,  -89.1834,  18.0413),
        ('M31', 2, 2.000e-05, -30.214 , 186.801, -81.1952,  116.911 , -106.009 , -47.9533),
        ...,
        ('MW ', 3, 2.001e-05, -83.8462, 216.74 , -49.5789,  134.361 ,  -40.9643,  26.6477),
        ('MW ', 3, 2.001e-05, -69.709 , 227.603, -68.0291,  118.112 ,  -49.8217, -13.2683),
        ('MW ', 3, 2.001e-05, -66.5686, 220.342, -44.1338, -187.922 ,  160.635 , 255.322 )],
       dtype=[('galname', '<U3'), ('type', 'u1'), ('m', '<f4'), ('x', '<f4'), ('y', '<f4'), ('z', '<f4'), ('vx', '<f4'), ('vy', '<f4'), ('vz', '<f4')]))

Get the data into a pandas df for easier analysis:

In [None]:
df = pd.DataFrame()
df['galname'] = bridge_data['galname']
df['type'] = bridge_data['type']
df['x'] = x[bridge_index]
df['y'] = y[bridge_index]
df['z'] = z[bridge_index]

In [None]:
# create some better column names
types = {1: 'Halo', 2: 'Disk', 3: 'Bulge'}
df['origin'] = df['type'].map(types)
df.head()

Show the counts in a pivot table:

In [None]:
df_piv = pd.pivot_table(df, values='x',
        index='galname', columns='origin',
        aggfunc='count', fill_value=0, margins=True)
df_piv

In [None]:
print(df_piv.to_latex(column_format='lrrr'))

In [None]:
import plotly.express as px

In [None]:
counts = df.groupby(['galname', 'origin'])['x'].count()
counts = counts.reset_index().rename(columns={'galname': 'Galaxy', 'x': 'Count'})
counts.head()

In [None]:
px.bar(counts, x='Galaxy', y='Count', color='origin')