From 407ed56c4b6bdf4ba32908aea0fb9032f900e5e7 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Wed, 8 Nov 2023 15:41:06 -0500 Subject: [PATCH] Cleanup mockstream (#5) * clean up mockstream * batch df sampling * comment out WIP code Signed-off-by: nstarman --- src/galdynamix/dynamics/mockstream/_df.py | 101 +++++++++++++- .../dynamics/mockstream/_mockstream.py | 126 +++++++----------- 2 files changed, 144 insertions(+), 83 deletions(-) diff --git a/src/galdynamix/dynamics/mockstream/_df.py b/src/galdynamix/dynamics/mockstream/_df.py index e52f160c..1460069f 100644 --- a/src/galdynamix/dynamics/mockstream/_df.py +++ b/src/galdynamix/dynamics/mockstream/_df.py @@ -6,6 +6,7 @@ __all__ = ["BaseStreamDF", "FardalStreamDF"] import abc +from typing import TYPE_CHECKING, Any, TypeAlias import equinox as eqx import jax @@ -15,6 +16,10 @@ from galdynamix.potential._potential.base import AbstractPotentialBase from galdynamix.utils import partial_jit +if TYPE_CHECKING: + _wifT: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array] + _carryT: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array] + class BaseStreamDF(eqx.Module): # type: ignore[misc] lead: bool = eqx.field(default=True, static=True) @@ -25,20 +30,103 @@ def __post_init__(self) -> None: msg = "You must generate either leading or trailing tails (or both!)" raise ValueError(msg) - @abc.abstractmethod + @partial_jit(static_argnames=("seed_num",)) def sample( + self, + # <\ parts of gala's ``prog_orbit`` + potential: AbstractPotentialBase, + prog_ws: jt.Array, + ts: jt.Numeric, + # /> + prog_mass: jt.Numeric, + *, + seed_num: int, + ) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]: + """Generate stream particle initial conditions. + + Parameters + ---------- + potential : AbstractPotentialBase + The potential of the host galaxy. + prog_ws : Array[(N, 6), float] + Columns are (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr] + Rows are at times `ts`. + prog_mass : Numeric + Mass of the progenitor in [Msol]. + TODO: allow this to be an array or function of time. + ts : Numeric + Times in [Myr] + + seed_num : int, keyword-only + PRNG seed + + Returns + ------- + x_lead, x_trail, v_lead, v_trail : Array + Positions and velocities of the leading and trailing tails. + """ + + def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]: + i = carry[0] + output = self._sample( + potential, + prog_ws[i, :3], + prog_ws[i, 3:], + prog_mass, + i, + t, + seed_num=seed_num, + ) + return (i + 1, *output), tuple(output) # type: ignore[return-value] + + init_carry = ( + 0, + xp.array([0.0, 0.0, 0.0]), + xp.array([0.0, 0.0, 0.0]), + xp.array([0.0, 0.0, 0.0]), + xp.array([0.0, 0.0, 0.0]), + ) + x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts[1:])[1] + return x_lead, x_trail, v_lead, v_trail + + @abc.abstractmethod + def _sample( self, potential: AbstractPotentialBase, x: jt.Array, v: jt.Array, - prog_mass: jt.Array, + prog_mass: jt.Numeric, i: int, - t: jt.Array, + t: jt.Numeric, *, seed_num: int, ) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]: - """Sample the DF.""" - raise NotImplementedError + """Generate stream particle initial conditions. + + Parameters + ---------- + potential : AbstractPotentialBase + The potential of the host galaxy. + x : Array + 3d position (x, y, z) in [kpc] + v : Array + 3d velocity (v_x, v_y, v_z) in [kpc/Myr] + prog_mass : Numeric + Mass of the progenitor in [Msol] + t : Numeric + Time in [Myr] + + i : int + PRNG multiplier + seed_num : int + PRNG seed + + Returns + ------- + x_lead, x_trail, v_lead, v_trail : Array + Positions and velocities of the leading and trailing tails. + """ + ... # ========================================================================== @@ -46,8 +134,9 @@ def sample( class FardalStreamDF(BaseStreamDF): @partial_jit(static_argnames=("seed_num",)) - def sample( + def _sample( self, + # parts of gala's ``prog_orbit`` potential: AbstractPotentialBase, x: jt.Array, v: jt.Array, diff --git a/src/galdynamix/dynamics/mockstream/_mockstream.py b/src/galdynamix/dynamics/mockstream/_mockstream.py index 50d22c61..a40c9470 100644 --- a/src/galdynamix/dynamics/mockstream/_mockstream.py +++ b/src/galdynamix/dynamics/mockstream/_mockstream.py @@ -5,7 +5,7 @@ __all__ = ["MockStreamGenerator"] -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import TYPE_CHECKING, TypeAlias import equinox as eqx import jax @@ -25,125 +25,97 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc] df: BaseStreamDF potential: AbstractPotentialBase - progenitor_potential: AbstractPotentialBase | None = None + # progenitor_potential: AbstractPotentialBase | None = None - @property - def self_gravity(self) -> bool: - return self.progenitor_potential is not None + # @property + # def self_gravity(self) -> bool: + # return self.progenitor_potential is not None # ========================================================================== - @partial_jit(static_argnames=("seed_num",)) - def _stream_ics( - self, ts: jt.Array, w0: jt.Array, mass: jt.Array, *, seed_num: int - ) -> jt.Array: - """Stream Initial Conditions. - - Parameters - ---------- - ts : array_like - Array of times to release particles. - w0 : array_like - q, p of the progenitor. - mass : array_like - Mass of the progenitor. - """ - ws = self.potential.integrate_orbit(w0, xp.min(ts), xp.max(ts), ts) - - def scan_fun(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]: - i = carry[0] - output = self.df.sample( - self.potential, ws[i, :3], ws[i, 3:], mass, i, t, seed_num=seed_num - ) - return (i + 1, *output), tuple(output) # type: ignore[return-value] - - init_carry = ( - 0, - xp.array([0.0, 0.0, 0.0]), - xp.array([0.0, 0.0, 0.0]), - xp.array([0.0, 0.0, 0.0]), - xp.array([0.0, 0.0, 0.0]), - ) - return jax.lax.scan(scan_fun, init_carry, ts[1:])[1] - @partial_jit(static_argnames=("seed_num",)) def _run_scan( self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int - ) -> tuple[jt.Array, jt.Array]: - """ - Generate stellar stream by scanning over the release model/integration. Better for CPU usage. + ) -> tuple[tuple[jt.Array, jt.Array], jt.Array]: + """Generate stellar stream by scanning over the release model/integration. + + Better for CPU usage. """ - q_close, q_far, p_close, p_far = self._stream_ics( - ts, prog_w0, prog_mass, seed_num=seed_num + # Integrate the progenitor orbit + prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts) + + # Generate stream initial conditions along the integrated progenitor orbit + x_lead, x_trail, v_lead, v_trail = self.df.sample( + self.potential, prog_ws, ts, prog_mass, seed_num=seed_num ) - # TODO: make this a separated method - @jax.jit # type: ignore[misc] - def scan_fun( + def scan_fn( carry: _carryT, particle_idx: int ) -> tuple[_carryT, tuple[jt.Array, jt.Array]]: - i, q_close_i, q_far_i, p_close_i, p_far_i = carry - w0_close_i = xp.hstack([q_close_i, p_close_i]) - w0_far_i = xp.hstack([q_far_i, p_far_i]) - w0_lead_trail = xp.vstack([w0_close_i, w0_far_i]) + i, x_lead_i, x_trail_i, v_lead_i, v_trail_i = carry + w0_lead_i = xp.hstack([x_lead_i, v_lead_i]) + w0_trail_i = xp.hstack([x_trail_i, v_trail_i]) + w0_lead_trail = xp.vstack([w0_lead_i, w0_trail_i]) minval, maxval = ts[i], ts[-1] integ_ics = lambda ics: self.potential.integrate_orbit( # noqa: E731 ics, minval, maxval, None )[0] # vmap over leading and trailing arm - w_close, w_far = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail) + w_lead, w_trail = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail) carry_out = ( i + 1, - q_close[i + 1, :], - q_far[i + 1, :], - p_close[i + 1, :], - p_far[i + 1, :], + x_lead[i + 1, :], + x_trail[i + 1, :], + v_lead[i + 1, :], + v_trail[i + 1, :], ) - return carry_out, (w_close, w_far) + return carry_out, (w_lead, w_trail) - carry_init = (0, q_close[0, :], q_far[0, :], p_close[0, :], p_far[0, :]) - particle_ids = xp.arange(len(q_close)) - lead_arm, trail_arm = jax.lax.scan(scan_fun, carry_init, particle_ids)[1] - return lead_arm, trail_arm + carry_init = (0, x_lead[0, :], x_trail[0, :], v_lead[0, :], v_trail[0, :]) + particle_ids = xp.arange(len(x_lead)) + lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1] + return (lead_arm, trail_arm), prog_ws @partial_jit(static_argnames=("seed_num",)) def _run_vmap( self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int - ) -> tuple[jt.Array, jt.Array]: + ) -> tuple[tuple[jt.Array, jt.Array], jt.Array]: """ Generate stellar stream by vmapping over the release model/integration. Better for GPU usage. """ - q_close_arr, q_far_arr, p_close_arr, p_far_arr = self._stream_ics( - ts, prog_w0, prog_mass, seed_num=seed_num + # Integrate the progenitor orbit + prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts) + + # Generate stream initial conditions along the integrated progenitor orbit + x_lead, x_trail, v_lead, v_trail = self.df.sample( + self.potential, prog_ws, ts, prog_mass, seed_num=seed_num ) # TODO: make this a separated method @jax.jit # type: ignore[misc] def single_particle_integrate( i: int, - q_close_i: jt.Array, - q_far_i: jt.Array, - p_close_i: jt.Array, - p_far_i: jt.Array, + x_lead_i: jt.Array, + x_trail_i: jt.Array, + v_lead_i: jt.Array, + v_trail_i: jt.Array, ) -> tuple[jt.Array, jt.Array]: - w0_close_i = xp.hstack([q_close_i, p_close_i]) - w0_far_i = xp.hstack([q_far_i, p_far_i]) + w0_lead_i = xp.hstack([x_lead_i, v_lead_i]) + w0_trail_i = xp.hstack([x_trail_i, v_trail_i]) t_i = ts[i] t_f = ts[-1] + 0.01 - w_close = self.integrate_orbit(w0_close_i, t_i, t_f, None)[0] - w_far = self.integrate_orbit(w0_far_i, t_i, t_f, None)[0] + w_lead = self.integrate_orbit(w0_lead_i, t_i, t_f, None)[0] + w_trail = self.integrate_orbit(w0_trail_i, t_i, t_f, None)[0] - return w_close, w_far + return w_lead, w_trail - particle_ids = xp.arange(len(q_close_arr)) + particle_ids = xp.arange(len(x_lead)) integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0, 0, 0)) - w_close, w_far = integrator( - particle_ids, q_close_arr, q_far_arr, p_close_arr, p_far_arr - ) - return w_close, w_far + w_lead, w_trail = integrator(particle_ids, x_lead, x_trail, v_lead, v_trail) + return (w_lead, w_trail), prog_ws @partial_jit(static_argnames=("seed_num", "vmapped")) def run(