In [None]:
import gala.potential as gp
import gala.integrate as gi
import gala.dynamics as gd
import astropy.units as auni
import astropy.coordinates as acoo
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import binned_statistic_2d
pot = gp.MilkyWayPotential2022()
kms = auni.km / auni.s

acoo.galactocentric_frame_defaults.set('latest')



def doit(vej, cosang, t):
    sinang = np.sqrt(1 - cosang**2)
    # start from 10 pc radius
    startpos = np.array([cosang, 0, sinang]) * 0.01 * auni.kpc
    vel = np.array([cosang, 0, sinang]) * vej
    w0 = gd.PhaseSpacePosition(startpos, vel=vel * kms)
    nsteps = 10000
    timestep = t * auni.Myr / nsteps
    orbit = gp.Hamiltonian(pot).integrate_orbit(
        w0,
        dt=timestep,
        n_steps=nsteps,
        Integrator=gi.DOPRI853Integrator,
        # Integrator_kwargs = dict(atol=1e-15,rtol=1e-15)
    )
    R = (orbit.x**2 + orbit.z**2 + orbit.y**2)**.5
    z = orbit.z.to_value(auni.kpc)
    VR = orbit.v_x * R / orbit.x
    Vz = orbit.v_z - orbit.v_x * orbit.z / orbit.x
    return R.to_value(auni.kpc), z, VR.to_value(kms), Vz.to_value(kms)


def doall(N=10000, seed=3):
    rng = np.random.default_rng(seed)
    vej = 10**rng.uniform(2.8, 3.5, size=N)
    cosa = rng.uniform(0, 1, size=N)
    times = 100
    r0, r1, r2, r3, r4 = [], [], [], [], []
    for curv, curc in zip(vej, cosa):
        R, z, VR, Vz = doit(curv, curc, times)
        r0.append(np.zeros(len(R)) + curv)
        r1.append(R)
        r2.append(z)
        r3.append(VR)
        r4.append(Vz)

    return [np.concatenate(_) for _ in [r0, r1, r2, r3, r4]]


R, z, VR, Vz = doall()



KeyboardInterrupt: 

In [None]:


# Choose how many bins you want in each dimension
xbins = 50  # bins in z
ybins = 50  # bins in log10(VR)

# Flatten to go from (10000, 1000) to (10000000,)
zf = z.ravel()
VRf = VR.ravel()
Vzf = Vz.ravel()
Rf = R.ravel()



# Use binned_statistic_2d to compute the mean ratio in each (z, logVR) bin
# statistic='mean' can also be 'median', 'count', 'sum', etc.
stat, xedges, yedges, binnum = binned_statistic_2d(
    y= zf/Rf, # to avoid log(0) we add a small number
    x=np.log10(Rf),
    values=Vzf/Rf*VRf,
    statistic='mean',
    bins=[xbins, ybins]
)

# 'stat' is a 2D array of the same shape as bins: (xbins, ybins)
# We can plot it with pcolormesh (or imshow).
plt.figure(figsize=(8, 6))
# Note that stat is indexed as stat[ix, iy], but pcolormesh expects
# xedges (length xbins+1) along the horizontal axis, yedges (length ybins+1)
# along the vertical axis, so we transpose 'stat' in pcolormesh.
plt.pcolormesh(xedges, yedges, stat.T, cmap='jet', shading='auto')#, vmin= 0, vmax= 2e3)

# Colorbar labeling
cb = plt.colorbar()
cb.set_label(r'$V_z/R*V_R$ ')

plt.xlabel(r'$\log_{10}(R)$')
plt.ylabel(r'$z/R$')
#plt.title(r'')
plt.tight_layout()
#plt.xlim(
# -2, 2)
plt.show()

1.5

In [3]:
# construct a regular grid interpolator that will give Vz/R*VR for any R, z
import numpy as np
from scipy.interpolate import RegularGridInterpolator

xcenters = 0.5*(xedges[:-1] + xedges[1:])
ycenters = 0.5*(yedges[:-1] + yedges[1:])

interp_func = RegularGridInterpolator(
    (xcenters, ycenters),   # The 2D grid coordinates
    stat,                   # The mean ratio table on that grid
    method='linear',        # or 'nearest', 'cubic', etc.
    bounds_error=False,     # If False, points outside will not raise an error...
    fill_value=np.nan       # ...and will return NaN. (You can choose 0 or None, etc.)
)

In [7]:
# see how long it takes to evalute the interpolator for 1000 points
logR_new = np.linspace(0.1, 2.3, 1000)
zr_new   = np.linspace(0.1, 0.83, 1000)

import time
t0 = time.time()
val = interp_func((logR_new, zr_new))
print("Interpolation time:", (time.time() - t0)/1000, "seconds")

Interpolation time: 1.6624927520751953e-06 seconds


In [25]:
import pickle
import sys
import os
# set current directory to /app/data so we can work with relative paths
os.chdir('/app/data/')
# Add the path to the 'scripts' folder directly
sys.path.append('/app/data/')

# Save the interpolator to a file
with open('/app/data/Data/vz_interpolator/RVR.pkl', 'wb') as f:
    pickle.dump(interp_func, f)