In [None]:
# ------------------------------------------------------------------------
#
# TITLE - apo_mock_spatial
# AUTHOR - James Lane
# PROJECT - ges-mass
#
# ------------------------------------------------------------------------
#
# Docstrings and metadata:
'''Create the framework to generate mock APOGEE realizations of the stellar 
halo that can be used to test fitting routines. Investigate spatial sampling
'''

__author__ = "James Lane"

In [None]:
### Imports

## Basic
import numpy as np
import sys, os, pdb, copy, glob, subprocess, warnings, dill as pickle
from astropy import units as apu

## Matplotlib
# import matplotlib as mpl
from matplotlib import pyplot as plt

## galpy
from galpy import orbit
from galpy import potential
# from galpy import actionAngle as aA
from galpy import df
from galpy.util import _rotate_to_arbitrary_vector

## scipy
import scipy.integrate
import scipy.interpolate

## Project-specific
sys.path.insert(0,'../../../src/')
# import sample_project.module as project_module

### Scale parameters
# Passed as the ro/vo/zo keywords of galpy.orbit.Orbit through
# sample_positions() below; ro is also used to scale radii (e.g. rc = 30./ro)
# before they are handed to the sampling routines.
# NOTE(review): presumably ro and zo are in kpc and vo in km/s (the galpy 
# convention, and the plots below label positions in kpc) - confirm.
ro = 8.275
vo = 220
zo = 0.0208 # Bennett+ 2019

### Notebook setup
%matplotlib inline
plt.style.use('../../../src/mpl/project.mplstyle') # This must be exactly here
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2

## Plan
The algorithm will work as follows:
1. Define a density function that represents the mass profile to be fit
2. Draw a sample of stellar masses from the IMF whose summed mass equals the total mass (this sets the total number of stars, N)
3. Assign 3D positions to samples according to density profile
4. Apply the APOGEE selection function
5. Sample kinematics if required

## Functions

In [None]:
def sample_positions(denspot,n=1,r_min=0.,r_max=np.inf,a=None,b=None,c=None,
                     zvec=None,pa=None,alpha=None,beta=None,gamma=None,
                     return_orbits=False,ro=8.,vo=220.,zo=0.):
    '''sample_positions:
    
    Draw position samples from the density profile. The density profile 
    must be spherically symmetric. Density profile can be modified to be 
    triaxial using b and/or c, and then rotated using either an axis-angle 
    (zvec and pa) or yaw-pitch-roll (alpha, beta, gamma) methods. The 
    axis-angle method takes precedence if both are fully specified. A 
    rotation is only applied when all of its parameters are supplied.
    
    Radii are handled in the internal units of denspot (masses are evaluated 
    with use_physical=False). Velocities of the returned orbits are all zero.
    
    Args:
        denspot (galpy.potential.Potential) - galpy potential representing
            the density profile. Must be spherical
        n (int) - Number of samples to draw (cast to int, so 1e6 is allowed)
        r_min (float) - Minimum radius of the samples
        r_max (float) - Maximum radius of the samples
        a (float) - Density profile scale radius (optional, the default of 
            _sample_r is used if None)
        b (float) - triaxial y/x scale ratio (optional, no scaling if None)
        c (float) - triaxial z/x scale ratio (optional, no scaling if None)
        zvec (list) - z-axis to align the new coordinate system
        pa (float) - Rotation about the transformed z-axis
        alpha (float) - Roll rotation about the x-axis 
        beta (float) - Pitch rotation about the transformed y-axis
        gamma (float) - Yaw rotation around twice-transformed z-axis
        return_orbits (bool) - Return the orbits
        ro,vo,zo (float) - Scale parameters passed to galpy.orbit.Orbit
    
    Returns:
        orbs (galpy.orbit.Orbit) - orbits (optional, otherwise None)
    '''
    # np.zeros and np.random.uniform require an integer size
    n = int(n)
    
    # Draw radial and angular samples. Only forward a when it was supplied,
    # passing a=None on would break the r <-> xi conversion in _sample_r
    if a is None:
        r_samples = _sample_r(denspot,n=n,r_min=r_min,r_max=r_max)
    else:
        r_samples = _sample_r(denspot,n=n,r_min=r_min,r_max=r_max,a=a)
    ##ie
    phi_samples,theta_samples = _sample_position_angles(n=n)
    R_samples = r_samples*np.sin(theta_samples)
    z_samples = r_samples*np.cos(theta_samples)
    
    # Apply triaxial scalings and possibly a rotation
    use_zvecpa = zvec is not None and pa is not None
    use_abg = alpha is not None and beta is not None and gamma is not None
    if b is not None or c is not None or use_zvecpa or use_abg:
        x_samples = R_samples*np.cos(phi_samples)
        y_samples = R_samples*np.sin(phi_samples)
        # A missing scale ratio means that axis is left unscaled
        if b is not None:
            y_samples = y_samples*b
        ##fi
        if c is not None:
            z_samples = z_samples*c
        ##fi
        if use_zvecpa:
            x_samples,y_samples,z_samples = _transform_zvecpa(x_samples,
                y_samples, z_samples, zvec, pa)
        elif use_abg:
            x_samples,y_samples,z_samples = _transform_alpha_beta_gamma(
                x_samples, y_samples, z_samples, alpha, beta, gamma)
        ##ei
        # The transforms return scalars for a single sample, restore 
        # arrays so the phase-space array below has a regular shape
        x_samples = np.atleast_1d(x_samples)
        y_samples = np.atleast_1d(y_samples)
        z_samples = np.atleast_1d(z_samples)
        R_samples = np.sqrt(x_samples**2.+y_samples**2.)
        phi_samples = np.arctan2(y_samples,x_samples)
    ##fi
    
    # Make into orbits, columns are [R,vR,vT,z,vz,phi] with zero velocities
    orbs = orbit.Orbit(vxvv=np.array([R_samples,np.zeros(n),np.zeros(n),
        z_samples,np.zeros(n),phi_samples]).T,ro=ro,vo=vo,zo=zo)
    if return_orbits:
        return orbs
    ##fi
#def

def _sample_r(denspot,n=1,r_min=0.,r_max=np.inf,a=1.):
    '''_sample_r:
    
    Draw radial position samples by inverse transform sampling. Uniform 
    deviates are mapped through the inverse cumulative mass function (iCMF), 
    which is interpolated as a function of the variable xi, defined as:
    
    .. math:: \\xi = \\frac{r/a-1}{r/a+1}
    
    so that xi is in the range [-1,1], which corresponds to an r range of 
    [0,infinity). The xi samples are converted back to r before returning.
    
    Args:
        denspot (galpy.potential.Potential) - galpy potential representing
            the density profile. Must be spherical
        n (int) - Number of samples (cast to int)
        r_min (float) - Minimum radius used to build the iCMF
        r_max (float) - Maximum radius used to build the iCMF
        a (float) - Scale radius used in the r <-> xi mapping
        
    Returns:
        r_samples (np.ndarray) - Radial position samples
    '''
    n_draw = int(n)
    
    # Build the interpolator first, then draw the uniform deviates
    xi_of_icmf = _make_icmf_xi_interpolator(denspot,r_min=r_min,r_max=r_max,
        a=a)
    uniform_draws = np.random.uniform(size=n_draw)
    
    # Map uniform deviates -> xi -> r
    return _xi_to_r(xi_of_icmf(uniform_draws),a=a)
#def

def _sample_position_angles(n=1):
    '''_sample_position_angles:
    
    Draw galactocentric, spherical angle samples.
    
    Args:
        n (int) - Number of samples
    
    Returns:
        phi_samples (np.ndarray) - Spherical azimuth
        theta_samples (np.ndarray) - Spherical polar angle
    '''
    phi_samples= np.random.uniform(size=n)*2*np.pi
    theta_samples= np.arccos(1.-2*np.random.uniform(size=n))
    return phi_samples,theta_samples
#def

def _make_icmf_xi_interpolator(denspot,r_min=0.,r_max=np.inf,a=1.):
    '''_make_icmf_xi_interpolator:

    Create the interpolator object which maps the normalized cumulative 
    mass function (CMF) between r_min and r_max onto the variable xi, 
    i.e. the inverse CMF (iCMF) expressed in xi.
    Note - must use _xi_to_r() on any output of the interpolator
    Note - xi is defined as:

    .. math:: \\xi = \\frac{r/a-1}{r/a+1}

    so that xi is in the range [-1,1], which corresponds to an r range of 
    [0,infinity)

    Args:
        denspot (galpy.potential.Potential) - galpy potential representing
            the density profile
        r_min (float) - Radius where the normalized CMF is 0
        r_max (float) - Radius where the normalized CMF is 1
        a (float) - Scale radius used in the r <-> xi mapping

    Returns
        icmf_xi_interpolator (scipy.interpolate.InterpolatedUnivariateSpline) -
            Cubic spline giving xi as a function of normalized enclosed mass
    '''
    # Regular grid in xi (step 1e-4, upper end excluded) and matching radii
    xi_lo = _r_to_xi(r_min,a=a)
    xi_hi = _r_to_xi(r_max,a=a)
    xi_grid = np.arange(xi_lo,xi_hi,1e-4)
    r_grid = _xi_to_r(xi_grid,a=a)

    # Enclosed mass on the grid and at the outer edge
    cmf = potential.mass(denspot,r_grid,use_physical=False)
    m_norm = potential.mass(denspot,r_max,use_physical=False)

    # Remove the mass interior to r_min so the CMF starts at zero
    if r_min > 0:
        m_inner = potential.mass(denspot,r_min,use_physical=False)
        cmf -= m_inner
        m_norm -= m_inner
    cmf /= m_norm
        
    # For an infinite r_max add the total mass point (xi=1, CMF=1) by hand
    if np.isinf(r_max):
        xi_grid = np.append(xi_grid,1)
        cmf = np.append(cmf,1)
    return scipy.interpolate.InterpolatedUnivariateSpline(cmf,xi_grid,k=3)
#def

def _r_to_xi(r,a=1.):
    '''_r_to_xi:
    
    Convert r to the variable xi
    '''
    out= np.divide((r/a-1.),(r/a+1.),where=True^np.isinf(r))
    if np.any(np.isinf(r)):
        if hasattr(r,'__len__'):
            out[np.isinf(r)]= 1.
        else:
            return 1.
    return out
#def

def _xi_to_r(xi,a=1.):
    '''_xi_to_r:
    
    Convert the variable xi to r
    '''
    return a*np.divide(1.+xi,1.-xi)
#def

def _transform_zvecpa(x,y,z,zvec,pa):
    '''_transform_zvecpa:
    
    Transform coordinates using the axis-angle method. First align the
    z-axis of the coordinate system with a vector (zvec) and then rotate 
    about the new z-axis by an angle (pa).
    
    Args:
        x,y,z (array) - Coordinates
        zvec (list or array) - z-axis to align the new coordinate system, 
            does not need to be normalized and is not modified
        pa (float) - Rotation about the transformed z-axis
        
    Returns:
        x_rot,y_rot,z_rot (array) - Rotated coordinates (scalars if a 
            single coordinate triplet is supplied)
    
    Raises:
        ValueError - If zvec has zero length
    '''
    pa_rot = np.array([[ np.cos(pa), np.sin(pa), 0.],
                      [-np.sin(pa), np.cos(pa), 0.],
                      [0.         , 0.        , 1.]])

    # Normalize a float copy of zvec. An in-place division would modify the 
    # array of the caller and fail for lists and integer arrays
    zvec = np.asarray(zvec,dtype=float)
    zvec_norm = np.sqrt(np.sum(zvec**2.))
    if zvec_norm == 0.:
        raise ValueError('zvec must have non-zero length')
    ##fi
    zvec = zvec/zvec_norm
    zvec_rot = _rotate_to_arbitrary_vector(np.array([[0.,0.,1.]]),
                                           zvec,inv=True)[0]
    # Align with zvec first, then rotate by pa
    R = np.dot(pa_rot,zvec_rot)
    
    # Stack into (n,3), which squeezes to (3,) for a single triplet
    xyz = np.squeeze(np.dstack([x,y,z]))
    if np.ndim(xyz) == 1:
        xyz_rot = np.dot(R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[0],xyz_rot[1],xyz_rot[2]
    else:
        xyz_rot = np.einsum('ij,aj->ai', R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[:,0],xyz_rot[:,1],xyz_rot[:,2]
    return x_rot,y_rot,z_rot
#def

def _transform_alpha_beta_gamma(x,y,z,alpha,beta,gamma):
    '''_transform_alpha_beta_gamma:
    
    Transform x,y,z coordinates by a yaw-pitch-roll transformation.
    
    Args:
        x,y,z (array) - Coordinates
        alpha (float) - Roll rotation about the x-axis 
        beta (float) - Pitch rotation about the transformed y-axis
        gamma (float) - Yaw rotation around twice-transformed z-axis
        
    Returns:
        x_rot,y_rot,z_rot (array) - Rotated coordinates 
    '''
    # Roll matrix
    Rx = np.zeros([3,3])
    Rx[0,0] = 1
    Rx[1]   = [0           , np.cos(alpha), -np.sin(alpha)]
    Rx[2]   = [0           , np.sin(alpha), np.cos(alpha)]
    # Pitch matrix
    Ry = np.zeros([3,3])
    Ry[0]   = [np.cos(beta), 0            , np.sin(beta)]
    Ry[1,1] = 1
    Ry[2]   = [-np.sin(beta), 0, np.cos(beta)]
    # Yaw matrix
    Rz = np.zeros([3,3])
    Rz[0]   = [np.cos(gamma), -np.sin(gamma), 0]
    Rz[1]   = [np.sin(gamma), np.cos(gamma), 0]
    Rz[2,2] = 1
    R = np.matmul(Rx,np.matmul(Ry,Rz))
    
    xyz = np.squeeze(np.dstack([x,y,z]))
    if np.ndim(xyz) == 1:
        xyz_rot = np.dot(R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[0],xyz_rot[1],xyz_rot[2]
    else:
        xyz_rot = np.einsum('ij,aj->ai', R, xyz)
        x_rot,y_rot,z_rot = xyz_rot[:,0],xyz_rot[:,1],xyz_rot[:,2]
    return x_rot,y_rot,z_rot
#def

In [None]:
# Power-law density with exponential cutoff, radii are divided by ro
rc = 30./ro
alpha = 3.5
denspot = potential.PowerSphericalPotentialwCutoff(amp=1.,
    r1=1.,alpha=alpha,rc=rc)
potential.turn_physical_off(denspot)

# Radial range of the samples. These are also used as the histogram range 
# in the next cell, so the sampling call must use the same values
r_min = 1./ro
r_max = 50./ro
# NOTE(review): _scale is a private galpy attribute, presumably the scale 
# radius of the profile - confirm
a = denspot._scale
n = int(1e6)
rs = _sample_r(denspot,n=n,r_min=r_min,r_max=r_max,a=a)

In [None]:
# Histogram of the sampled radii in log r, the cutoff radius is marked by 
# the dashed line
fig,ax = plt.subplots()

log_rs = np.log10(rs)
log_r_range = (np.log10(r_min),np.log10(r_max))
ax.hist(log_rs, bins=20, range=log_r_range, histtype='step')

ax.axvline(np.log10(rc), color='Red', linestyle='dashed')
ax.set_yscale('log')
ax.set_xlabel('log r')
ax.set_ylabel('N')

In [None]:
# Compare the empirical cumulative distribution of the sampled radii with 
# the enclosed mass of the profile, normalized over the sampled range
rs_sorted = np.sort(rs)
log_rs_sorted = np.log10(rs_sorted)

cml_mass_sampled = np.arange(0,n)/n
cml_mass_true = potential.mass(denspot,rs_sorted,use_physical=False)
cml_mass_true = cml_mass_true-cml_mass_true[0]
cml_mass_true = cml_mass_true/cml_mass_true[-1]

fig,ax = plt.subplots()

ax.plot(log_rs_sorted, cml_mass_sampled, color='DodgerBlue')
ax.plot(log_rs_sorted, cml_mass_true, color='Red', linestyle='dashed')

ax.set_xlabel(r'$\log r$')
ax.set_ylabel(r'mass')

### Now handle triaxiality

In order to generate a triaxial profile we sample from a spherical profile:

$r = \sqrt{x_{0}^2+y_{0}^2+z_{0}^2}$

We need to convert this to the coordinate:

$m = \sqrt{x_{T}^2+y_{T}^2/b^2+z_{T}^2/c^2}$

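The triaxial profile depends on $m$ in the same way that the spherical profile depends on $r$, so we match the two expressions term by term. For the $y$ term, multiplying $y_{0}^2$ by $b^2/b^2 = 1$ and identifying the result with $y_{T}^2/b^2$ gives: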

$y_{0}^2 b^2/b^2 = y_{T}^2 / b^2$

and so

$y_{T} = y_{0} b$

Similarly:

$z_{T} = z_{0} c$

So we just multiply the $y$ and $z$ coordinates by the scale parameters $b$ and $c$, respectively.

### Test the creation of orbits

In [None]:
# Triaxial axis ratios (y/x and z/x)
b=0.75
c=2.
# Radial limits are divided by ro, the same scale used to build the orbits, 
# rather than by a hard-coded 8
orbs = sample_positions(denspot,n=int(1e7),r_min=1./ro,r_max=50./ro,a=a,
                        b=b,c=c,return_orbits=True,ro=ro,vo=vo,zo=zo)

In [None]:
tick_locs = np.arange(-10,12.5,2.5)

fig = plt.figure(figsize=(10,10))
ax1 = fig.add_subplot(221,aspect='equal')
ax2 = fig.add_subplot(222,aspect='equal')
ax3 = fig.add_subplot(223,aspect='equal')
ax4 = fig.add_subplot(224,aspect='equal')

# 2D histograms of the projected positions. The histograms are indexed as 
# [first coordinate, second coordinate]
xy_hist,xedge,yedge = np.histogram2d(orbs.x().value,orbs.y().value,bins=51,
                                  range=[[-10,10],[-10,10]])
xdelta = np.diff(xedge)
ydelta = np.diff(yedge)
# Bin centers are the left edges plus half of the bin width
xcents,ycents = np.meshgrid(xedge[:-1]+xdelta/2., yedge[:-1]+ydelta/2.)

xz_hist,xedge,zedge = np.histogram2d(orbs.x().value,orbs.z().value,bins=51,
                                  range=[[-10,10],[-10,10]])
xdelta = np.diff(xedge)
zdelta = np.diff(zedge)
xcents,zcents = np.meshgrid(xedge[:-1]+xdelta/2., zedge[:-1]+zdelta/2.)

# Empty bins give log10(0) = -inf, which is expected here
with np.errstate(divide='ignore'):
    log_xy_hist = np.log10(xy_hist)
    log_xz_hist = np.log10(xz_hist)
##wi

# imshow draws row 0 at the top, so rot90 puts the second coordinate on the 
# vertical axis increasing upwards
ax1.imshow(np.rot90(log_xy_hist), cmap='rainbow', extent=(-10,10,-10,10))
ax1.set_xlabel('X [kpc]')
ax1.set_ylabel('Y [kpc]')

ax2.imshow(np.rot90(log_xz_hist), cmap='rainbow', extent=(-10,10,-10,10))
ax2.set_xlabel('X [kpc]')
ax2.set_ylabel('Z [kpc]')

# contour expects rows to follow the vertical coordinate grid in increasing 
# order, which is the transpose of the histogram (not rot90, which would 
# flip the vertical axis)
ax3.contour(xcents, ycents, log_xy_hist.T, 
            levels=np.arange(2,6,0.5), colors='Black')
ax3.set_xlabel('X [kpc]')
ax3.set_ylabel('Y [kpc]')
ax3.set_xticks(tick_locs)
ax3.set_yticks(tick_locs)
ax3.grid(True)

ax4.contour(xcents, zcents, log_xz_hist.T, 
            levels=np.arange(2,6,0.5), colors='Black')
ax4.set_xlabel('X [kpc]')
ax4.set_ylabel('Z [kpc]')
ax4.set_xticks(tick_locs)
ax4.set_yticks(tick_locs)
ax4.grid(True)

### Try rotation

In [None]:
# Triaxial axis ratios (y/x and z/x)
b=0.66
c=0.42
# Axis-angle rotation: new z-axis direction and rotation about it
zvec = np.array([0.,0.8,0.4])
pa = np.pi/3
# Yaw-pitch-roll angles, not passed to sample_positions below. Note that 
# alpha reuses the name of the density power-law index defined above
alpha = 0.
beta = 0.
gamma = 0.
# Radial limits are divided by ro, the same scale used to build the orbits, 
# rather than by a hard-coded 8
orbs = sample_positions(denspot,n=int(1e6),r_min=1./ro,r_max=50./ro,a=a,
                        b=b,c=c,zvec=zvec,pa=pa,return_orbits=True,
                        ro=ro,vo=vo,zo=zo)

In [None]:
tick_locs = np.arange(-10,12.5,2.5)

fig = plt.figure(figsize=(15,10))
ax1 = fig.add_subplot(231,aspect='equal')
ax2 = fig.add_subplot(232,aspect='equal')
ax3 = fig.add_subplot(233,aspect='equal')
# ax4 = fig.add_subplot(234,aspect='equal')
# ax5 = fig.add_subplot(235,aspect='equal')
# ax6 = fig.add_subplot(236,aspect='equal')

# 2D histograms of the projected positions. The histograms are indexed as 
# [first coordinate, second coordinate]. Bin centers are the left edges 
# plus half of the bin width
xy_hist,xy_xedge,xy_yedge = np.histogram2d(orbs.x().value,orbs.y().value,bins=51,
                                  range=[[-10,10],[-10,10]])
xy_xdelta = np.diff(xy_xedge)
xy_ydelta = np.diff(xy_yedge)
xy_xcents,xy_ycents = np.meshgrid(xy_xedge[:-1]+xy_xdelta/2.,
                                  xy_yedge[:-1]+xy_ydelta/2.)

xz_hist,xz_xedge,xz_zedge = np.histogram2d(orbs.x().value,orbs.z().value,bins=51,
                                  range=[[-10,10],[-10,10]])
xz_xdelta = np.diff(xz_xedge)
xz_zdelta = np.diff(xz_zedge)
xz_xcents,xz_zcents = np.meshgrid(xz_xedge[:-1]+xz_xdelta/2.,
                                  xz_zedge[:-1]+xz_zdelta/2.)

yz_hist,yz_yedge,yz_zedge = np.histogram2d(orbs.y().value,orbs.z().value,bins=51,
                                  range=[[-10,10],[-10,10]])
yz_ydelta = np.diff(yz_yedge)
yz_zdelta = np.diff(yz_zedge)
yz_ycents,yz_zcents = np.meshgrid(yz_yedge[:-1]+yz_ydelta/2.,
                                  yz_zedge[:-1]+yz_zdelta/2.)

# Empty bins give log10(0) = -inf, which is expected here
with np.errstate(divide='ignore'):
    log_xy_hist = np.log10(xy_hist)
    log_xz_hist = np.log10(xz_hist)
    log_yz_hist = np.log10(yz_hist)
##wi

# imshow draws row 0 at the top, so rot90 puts the second coordinate on the 
# vertical axis increasing upwards
ax1.imshow(np.rot90(log_xy_hist), cmap='rainbow', extent=(-10,10,-10,10))
ax1.set_xlabel('X [kpc]')
ax1.set_ylabel('Y [kpc]')

ax2.imshow(np.rot90(log_xz_hist), cmap='rainbow', extent=(-10,10,-10,10))
ax2.set_xlabel('X [kpc]')
ax2.set_ylabel('Z [kpc]')

ax3.imshow(np.rot90(log_yz_hist), cmap='rainbow', extent=(-10,10,-10,10))
ax3.set_xlabel('Y [kpc]')
ax3.set_ylabel('Z [kpc]')

# Optional contour panels. contour expects the transpose of the histogram 
# (rows follow the vertical coordinate in increasing order)
# ax4.contour(xy_xcents, xy_ycents, log_xy_hist.T, 
#             levels=np.arange(2,6,0.5), colors='Black')
# ax4.set_xlabel('X [kpc]')
# ax4.set_ylabel('Y [kpc]')
# ax4.set_xticks(tick_locs)
# ax4.set_yticks(tick_locs)
# ax4.grid(True)

# ax5.contour(xz_xcents, xz_zcents, log_xz_hist.T, 
#             levels=np.arange(2,6,0.5), colors='Black')
# ax5.set_xlabel('X [kpc]')
# ax5.set_ylabel('Z [kpc]')
# ax5.set_xticks(tick_locs)
# ax5.set_yticks(tick_locs)
# ax5.grid(True)

# ax6.contour(yz_ycents, yz_zcents, log_yz_hist.T, 
#             levels=np.arange(2,6,0.5), colors='Black')
# ax6.set_xlabel('Y [kpc]')
# ax6.set_ylabel('Z [kpc]')
# ax6.set_xticks(tick_locs)
# ax6.set_yticks(tick_locs)
# ax6.grid(True)

fig.tight_layout()
# Figure.show() only raises a warning with the inline backend
plt.show()

### Save some sample orbits

In [None]:
# Triaxial axis ratios (y/x and z/x)
b=0.5
c=2.
# Axis-angle rotation: new z-axis direction and rotation about it
zvec = np.array([0.,0.,1.])
pa = np.pi/4.
# Yaw-pitch-roll angles, not passed to sample_positions below. Note that 
# alpha reuses the name of the density power-law index defined above
alpha = 0.
beta = 0.
gamma = 0.
# Radial limits are divided by ro, the same scale used to build the orbits, 
# rather than by a hard-coded 8
orbs = sample_positions(denspot,n=int(1e8),r_min=2./ro,r_max=50./ro,a=a,
                        b=b,c=c,zvec=zvec,pa=pa,return_orbits=True,
                        ro=ro,vo=vo,zo=zo)

In [None]:
data_dir = '/geir_data/scr/lane/projects/ges-mass/apo_mocks/'
# Save the first 1e6, 1e7 and 1e8 orbits to separate files. Each entry is 
# (filename, number of orbits, extra keywords for pickle.dump); only the 
# largest file sets the pickle protocol explicitly
for fname,n_save,dump_kwargs in (('orbs_1e6.pkl',int(1e6),{}),
                                 ('orbs_1e7.pkl',int(1e7),{}),
                                 ('orbs_1e8.pkl',int(1e8),{'protocol':4})):
    with open(data_dir+fname,'wb') as f:
        pickle.dump(orbs[:n_save],f,**dump_kwargs)
    ##wi