Skip to content

Commit

Permalink
refactor: internally vectorize df._sample (GalacticDynamics#345)
Browse files Browse the repository at this point in the history
* refactor: internally vectorize df._sample
* fix: dphidr2 signature

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jun 12, 2024
1 parent e59335e commit 1cf28c4
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 68 deletions.
8 changes: 4 additions & 4 deletions src/galax/dynamics/_dynamics/mockstream/df/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""galax: Galactic Dynamix in Jax."""

from . import _base, _fardal, _progenitor
from . import _base, _fardal15, _progenitor
from ._base import *
from ._fardal import *
from ._fardal15 import *
from ._progenitor import *

__all__: list[str] = []
__all__ += _base.__all__
__all__ += _progenitor.__all__
__all__ += _fardal.__all__
__all__ += _fardal15.__all__

# Cleanup
del _base, _fardal, _progenitor
del _base, _fardal15, _progenitor
70 changes: 34 additions & 36 deletions src/galax/dynamics/_dynamics/mockstream/df/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""galax: Galactic Dynamix in Jax."""
"""Stream Distribution Functions for ejecting mock stream particles."""

__all__ = ["AbstractStreamDF"]

Expand All @@ -8,12 +8,10 @@

import equinox as eqx
import jax
import jax.random as jr
from jaxtyping import PRNGKeyArray
from plum import convert

import coordinax as cx
import quaxed.array_api as xp
from unxt import Quantity

import galax.coordinates as gc
Expand Down Expand Up @@ -45,7 +43,7 @@ def sample(
Parameters
----------
rng : :class:`jaxtyping.PRNGKeyArray`, positional-only
Pseudo-random number generator.
Pseudo-random number generator. Not split, used as is.
pot : :class:`~galax.potential.AbstractPotentialBase`, positional-only
The potential of the host galaxy.
prog_orbit : :class:`~galax.dynamics.Orbit`, positional-only
Expand Down Expand Up @@ -74,12 +72,22 @@ def sample(
>>> prog_orbit = pot.evaluate_orbit(w, t=Quantity([0, 1, 2], "Gyr"))
>>> stream_ic = df.sample(jr.key(0), pot, prog_orbit,
... prog_mass=Quantity(1e4, "Msun"))
>>> stream_ic
CompositePhaseSpacePosition({'lead': MockStreamArm(
q=CartesianPosition3D( ... ),
p=CartesianVelocity3D( ... ),
t=Quantity...,
release_time=Quantity... ),
'trail': MockStreamArm(
q=CartesianPosition3D( ... ),
p=CartesianVelocity3D( ... ),
t=Quantity...,
release_time=Quantity...
)})
"""
# Progenitor positions and times. The orbit times are used as the
# release times for the mock stream.
prog_orbit = prog_orbit.represent_as(cx.CartesianPosition3D)
xs = convert(prog_orbit.q, Quantity)
vs = convert(prog_orbit.p, Quantity)
ts = prog_orbit.t

# Progenitor mass
Expand All @@ -89,24 +97,14 @@ def sample(
else prog_mass
)

# Scan over the release times to generate the stream particle initial
# conditions at each release time.
def scan_fn(_: Carry, inputs: tuple[int, PRNGKeyArray]) -> tuple[Carry, Carry]:
i, key = inputs
out = self._sample(key, pot, xs[i], vs[i], mprog(ts[i]), ts[i])
return out, out

# TODO: use ``jax.vmap`` instead of ``jax.lax.scan``?
init_carry = (
xp.zeros_like(xs[0]),
xp.zeros_like(vs[0]),
xp.zeros_like(xs[0]),
xp.zeros_like(vs[0]),
x_lead, v_lead, x_trail, v_trail = self._sample(
rng,
pot,
convert(prog_orbit.q, Quantity),
convert(prog_orbit.p, Quantity),
mprog(ts),
ts,
)
subkeys = jr.split(rng, len(ts))
x_lead, v_lead, x_trail, v_trail = jax.lax.scan(
scan_fn, init_carry, (xp.arange(len(ts)), subkeys)
)[1]

mock_lead = MockStreamArm(
q=x_lead.to_units(pot.units["length"]),
Expand All @@ -127,12 +125,12 @@ def scan_fn(_: Carry, inputs: tuple[int, PRNGKeyArray]) -> tuple[Carry, Carry]:
@abc.abstractmethod
def _sample(
self,
rng: PRNGKeyArray,
pot: gp.AbstractPotentialBase,
x: gt.LengthVec3,
v: gt.SpeedVec3,
prog_mass: gt.FloatQScalar,
t: gt.FloatQScalar,
key: PRNGKeyArray,
potential: gp.AbstractPotentialBase,
x: gt.LengthBroadBatchVec3,
v: gt.SpeedBroadBatchVec3,
prog_mass: gt.BroadBatchFloatQScalar,
t: gt.BroadBatchFloatQScalar,
) -> tuple[
gt.LengthBatchVec3, gt.SpeedBatchVec3, gt.LengthBatchVec3, gt.SpeedBatchVec3
]:
Expand All @@ -142,22 +140,22 @@ def _sample(
----------
rng : :class:`jaxtyping.PRNGKeyArray`
Pseudo-random number generator.
pot : :class:`galax.potential.AbstractPotentialBase`
potential : :class:`galax.potential.AbstractPotentialBase`
The potential of the host galaxy.
x : Quantity[float, (3,), "length"]
x : Quantity[float, (*#batch, 3), "length"]
3d position (x, y, z)
v : Quantity[float, (3,), "speed"]
v : Quantity[float, (*#batch, 3), "speed"]
3d velocity (v_x, v_y, v_z)
prog_mass : Quantity[float, (), "mass"]
prog_mass : Quantity[float, (*#batch), "mass"]
Mass of the progenitor.
t : Quantity[float, (), "time"]
t : Quantity[float, (*#batch), "time"]
The release time of the stream particles.
Returns
-------
x_lead, v_lead: Quantity[float, (*shape, 3), "length" | "speed"]
x_lead, v_lead: Quantity[float, (*batch, 3), "length" | "speed"]
Position and velocity of the leading arm.
x_trail, v_trail : Quantity[float, (*shape, 3), "length" | "speed"]
x_trail, v_trail : Quantity[float, (*batch, 3), "length" | "speed"]
Position and velocity of the trailing arm.
"""
...
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
from unxt import Quantity
from unxt.experimental import grad

import galax.potential as gp
import galax.typing as gt
from ._base import AbstractStreamDF
from galax.potential import AbstractPotentialBase

# ============================================================
# Constants
Expand All @@ -28,8 +28,8 @@
kz_bar = 0.0
kvz_bar = 0.0

sigma_kr = 0.5
sigma_kvphi = 0.5
sigma_kr = 0.5 # TODO: use actual Fardal values
sigma_kvphi = 0.5 # TODO: use actual Fardal values
sigma_kz = 0.5
sigma_kvz = 0.5

Expand All @@ -49,38 +49,41 @@ class FardalStreamDF(AbstractStreamDF):
def _sample(
self,
key: PRNGKeyArray,
potential: AbstractPotentialBase,
x: gt.LengthVec3,
v: gt.SpeedVec3,
prog_mass: gt.FloatQScalar,
t: gt.FloatQScalar,
) -> tuple[gt.LengthVec3, gt.SpeedVec3, gt.LengthVec3, gt.SpeedVec3]:
potential: gp.AbstractPotentialBase,
x: gt.LengthBroadBatchVec3,
v: gt.SpeedBroadBatchVec3,
prog_mass: gt.BroadBatchFloatQScalar,
t: gt.BroadBatchFloatQScalar,
) -> tuple[
gt.LengthBatchVec3, gt.SpeedBatchVec3, gt.LengthBatchVec3, gt.SpeedBatchVec3
]:
"""Generate stream particle initial conditions."""
# Random number generation
key1, key2, key3, key4 = jr.split(key, 4)

omega_val = orbital_angular_velocity_mag(x, v)
omega_val = orbital_angular_velocity_mag(x, v)[..., None]

# r-hat
r = xp.linalg.vector_norm(x, axis=-1)
r = xp.linalg.vector_norm(x, axis=-1, keepdims=True)
r_hat = x / r

r_tidal = tidal_radius(potential, x, v, prog_mass, t)
r_tidal = tidal_radius(potential, x, v, prog_mass, t)[..., None]
v_circ = omega_val * r_tidal # relative velocity

# z-hat
L_vec = qnp.cross(x, v)
z_hat = L_vec / xp.linalg.vector_norm(L_vec, axis=-1)
z_hat = L_vec / xp.linalg.vector_norm(L_vec, axis=-1, keepdims=True)

# phi-hat
phi_vec = v - xp.sum(v * r_hat) * r_hat
phi_hat = phi_vec / xp.linalg.vector_norm(phi_vec, axis=-1)
phi_vec = v - xp.sum(v * r_hat, axis=-1, keepdims=True) * r_hat
phi_hat = phi_vec / xp.linalg.vector_norm(phi_vec, axis=-1, keepdims=True)

# k vals
kr_samp = kr_bar + jr.normal(key1, (1,)) * sigma_kr
kvphi_samp = kr_samp * (kvphi_bar + jr.normal(key2, (1,)) * sigma_kvphi)
kz_samp = kz_bar + jr.normal(key3, (1,)) * sigma_kz
kvz_samp = kvz_bar + jr.normal(key4, (1,)) * sigma_kvz
shape = r_tidal.shape
kr_samp = kr_bar + jr.normal(key1, shape) * sigma_kr
kvphi_samp = kr_samp * (kvphi_bar + jr.normal(key2, shape) * sigma_kvphi)
kz_samp = kz_bar + jr.normal(key3, shape) * sigma_kz
kvz_samp = kvz_bar + jr.normal(key4, shape) * sigma_kvz

# Trailing arm
x_trail = x + r_tidal * (kr_samp * r_hat + kz_samp * z_hat)
Expand Down Expand Up @@ -115,15 +118,15 @@ def r_hat(x: gt.LengthBatchVec3, /) -> Shaped[Quantity[""], "*batch 3"]:

@partial(jax.jit, inline=True)
def dphidr(
potential: AbstractPotentialBase,
potential: gp.AbstractPotentialBase,
x: gt.LengthBatchVec3,
t: Shaped[Quantity["time"], ""],
) -> Shaped[Quantity["acceleration"], "*batch"]:
"""Compute the derivative of the potential at a position x.
Parameters
----------
potential: AbstractPotentialBase
potential : `galax.potential.AbstractPotentialBase`
The gravitational potential.
x: Quantity[float, (3,), 'length']
3d position (x, y, z)
Expand All @@ -139,16 +142,17 @@ def dphidr(


@partial(jax.jit)
@partial(qnp.vectorize, excluded=(0,), signature="(3),()->()")
def d2phidr2(
potential: AbstractPotentialBase, x: gt.LengthVec3, /, t: gt.TimeScalar
potential: gp.AbstractPotentialBase, x: gt.LengthVec3, t: gt.TimeScalar, /
) -> Shaped[Quantity["1/s^2"], ""]:
"""Compute the second derivative of the potential.
At a position x (in the simulation frame).
Parameters
----------
potential: AbstractPotentialBase
potential : `galax.potential.AbstractPotentialBase`
The gravitational potential.
x: Quantity[Any, (3,), 'length']
3d position (x, y, z) in [kpc]
Expand All @@ -166,7 +170,7 @@ def d2phidr2(
>>> from galax.potential import NFWPotential
>>> pot = NFWPotential(m=1e12, r_s=20.0, units="galactic")
>>> q = Quantity(xp.asarray([8.0, 0.0, 0.0]), "kpc")
>>> d2phidr2(pot, q, t=Quantity(0.0, "Myr"))
>>> d2phidr2(pot, q, Quantity(0.0, "Myr"))
Quantity['1'](Array(-0.0001747, dtype=float64), unit='1 / Myr2')
"""
rhat = r_hat(x)
Expand Down Expand Up @@ -234,7 +238,7 @@ def orbital_angular_velocity_mag(

@partial(jax.jit)
def tidal_radius(
potential: AbstractPotentialBase,
potential: gp.AbstractPotentialBase,
x: gt.LengthVec3,
v: gt.SpeedVec3,
/,
Expand All @@ -245,7 +249,7 @@ def tidal_radius(
Parameters
----------
potential: AbstractPotentialBase
potential : `galax.potential.AbstractPotentialBase`
The gravitational potential of the host.
x: Quantity[float, (3,), "length"]
3d position (x, y, z).
Expand Down Expand Up @@ -283,7 +287,7 @@ def tidal_radius(

@partial(jax.jit)
def lagrange_points(
potential: AbstractPotentialBase,
potential: gp.AbstractPotentialBase,
x: gt.LengthVec3,
v: gt.SpeedVec3,
prog_mass: gt.MassScalar,
Expand All @@ -293,7 +297,7 @@ def lagrange_points(
Parameters
----------
potential: AbstractPotentialBase
potential : `galax.potential.AbstractPotentialBase`
The gravitational potential of the host.
x: Quantity[float, (3,), "length"]
3d position (x, y, z)
Expand Down
1 change: 1 addition & 0 deletions src/galax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,4 @@

SpeedVec3: TypeAlias = Shaped[Quantity["speed"], "3"]
SpeedBatchVec3: TypeAlias = Shaped[SpeedVec3, "*batch"]
SpeedBroadBatchVec3: TypeAlias = Shaped[SpeedVec3, "*#batch"]

0 comments on commit 1cf28c4

Please sign in to comment.