Skip to content

Commit

Permalink
Cleanup mockstream (#5)
Browse files Browse the repository at this point in the history
* clean up mockstream
* batch df sampling
* comment out WIP code

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 8, 2023
1 parent 4bfacc1 commit 407ed56
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 83 deletions.
101 changes: 95 additions & 6 deletions src/galdynamix/dynamics/mockstream/_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
__all__ = ["BaseStreamDF", "FardalStreamDF"]

import abc
from typing import TYPE_CHECKING, Any, TypeAlias

import equinox as eqx
import jax
Expand All @@ -15,6 +16,10 @@
from galdynamix.potential._potential.base import AbstractPotentialBase
from galdynamix.utils import partial_jit

if TYPE_CHECKING:
_wifT: TypeAlias = tuple[jt.Array, jt.Array, jt.Array, jt.Array]
_carryT: TypeAlias = tuple[int, jt.Array, jt.Array, jt.Array, jt.Array]


class BaseStreamDF(eqx.Module): # type: ignore[misc]
lead: bool = eqx.field(default=True, static=True)
Expand All @@ -25,29 +30,113 @@ def __post_init__(self) -> None:
msg = "You must generate either leading or trailing tails (or both!)"
raise ValueError(msg)

@abc.abstractmethod
@partial_jit(static_argnames=("seed_num",))
def sample(
self,
# <\ parts of gala's ``prog_orbit``
potential: AbstractPotentialBase,
prog_ws: jt.Array,
ts: jt.Numeric,
# />
prog_mass: jt.Numeric,
*,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""Generate stream particle initial conditions.
Parameters
----------
potential : AbstractPotentialBase
The potential of the host galaxy.
prog_ws : Array[(N, 6), float]
Columns are (x, y, z) [kpc], (v_x, v_y, v_z) [kpc/Myr]
Rows are at times `ts`.
prog_mass : Numeric
Mass of the progenitor in [Msol].
TODO: allow this to be an array or function of time.
ts : Numeric
Times in [Myr]
seed_num : int, keyword-only
PRNG seed
Returns
-------
x_lead, x_trail, v_lead, v_trail : Array
Positions and velocities of the leading and trailing tails.
"""

def scan_fn(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]:
i = carry[0]
output = self._sample(
potential,
prog_ws[i, :3],
prog_ws[i, 3:],
prog_mass,
i,
t,
seed_num=seed_num,
)
return (i + 1, *output), tuple(output) # type: ignore[return-value]

init_carry = (
0,
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
)
x_lead, x_trail, v_lead, v_trail = jax.lax.scan(scan_fn, init_carry, ts[1:])[1]
return x_lead, x_trail, v_lead, v_trail

@abc.abstractmethod
def _sample(
self,
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
prog_mass: jt.Array,
prog_mass: jt.Numeric,
i: int,
t: jt.Array,
t: jt.Numeric,
*,
seed_num: int,
) -> tuple[jt.Array, jt.Array, jt.Array, jt.Array]:
"""Sample the DF."""
raise NotImplementedError
"""Generate stream particle initial conditions.
Parameters
----------
potential : AbstractPotentialBase
The potential of the host galaxy.
x : Array
3d position (x, y, z) in [kpc]
v : Array
3d velocity (v_x, v_y, v_z) in [kpc/Myr]
prog_mass : Numeric
Mass of the progenitor in [Msol]
t : Numeric
Time in [Myr]
i : int
PRNG multiplier
seed_num : int
PRNG seed
Returns
-------
x_lead, x_trail, v_lead, v_trail : Array
Positions and velocities of the leading and trailing tails.
"""
...


# ==========================================================================


class FardalStreamDF(BaseStreamDF):
@partial_jit(static_argnames=("seed_num",))
def sample(
def _sample(
self,
# parts of gala's ``prog_orbit``
potential: AbstractPotentialBase,
x: jt.Array,
v: jt.Array,
Expand Down
126 changes: 49 additions & 77 deletions src/galdynamix/dynamics/mockstream/_mockstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

__all__ = ["MockStreamGenerator"]

from typing import TYPE_CHECKING, Any, TypeAlias
from typing import TYPE_CHECKING, TypeAlias

import equinox as eqx
import jax
Expand All @@ -25,125 +25,97 @@
class MockStreamGenerator(eqx.Module): # type: ignore[misc]
df: BaseStreamDF
potential: AbstractPotentialBase
progenitor_potential: AbstractPotentialBase | None = None
# progenitor_potential: AbstractPotentialBase | None = None

@property
def self_gravity(self) -> bool:
return self.progenitor_potential is not None
# @property
# def self_gravity(self) -> bool:
# return self.progenitor_potential is not None

# ==========================================================================

@partial_jit(static_argnames=("seed_num",))
def _stream_ics(
self, ts: jt.Array, w0: jt.Array, mass: jt.Array, *, seed_num: int
) -> jt.Array:
"""Stream Initial Conditions.
Parameters
----------
ts : array_like
Array of times to release particles.
w0 : array_like
q, p of the progenitor.
mass : array_like
Mass of the progenitor.
"""
ws = self.potential.integrate_orbit(w0, xp.min(ts), xp.max(ts), ts)

def scan_fun(carry: _carryT, t: Any) -> tuple[_carryT, _wifT]:
i = carry[0]
output = self.df.sample(
self.potential, ws[i, :3], ws[i, 3:], mass, i, t, seed_num=seed_num
)
return (i + 1, *output), tuple(output) # type: ignore[return-value]

init_carry = (
0,
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
xp.array([0.0, 0.0, 0.0]),
)
return jax.lax.scan(scan_fun, init_carry, ts[1:])[1]

@partial_jit(static_argnames=("seed_num",))
def _run_scan(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[jt.Array, jt.Array]:
"""
Generate stellar stream by scanning over the release model/integration. Better for CPU usage.
) -> tuple[tuple[jt.Array, jt.Array], jt.Array]:
"""Generate stellar stream by scanning over the release model/integration.
Better for CPU usage.
"""
q_close, q_far, p_close, p_far = self._stream_ics(
ts, prog_w0, prog_mass, seed_num=seed_num
# Integrate the progenitor orbit
prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)

# Generate stream initial conditions along the integrated progenitor orbit
x_lead, x_trail, v_lead, v_trail = self.df.sample(
self.potential, prog_ws, ts, prog_mass, seed_num=seed_num
)

# TODO: make this a separated method
@jax.jit # type: ignore[misc]
def scan_fun(
def scan_fn(
carry: _carryT, particle_idx: int
) -> tuple[_carryT, tuple[jt.Array, jt.Array]]:
i, q_close_i, q_far_i, p_close_i, p_far_i = carry
w0_close_i = xp.hstack([q_close_i, p_close_i])
w0_far_i = xp.hstack([q_far_i, p_far_i])
w0_lead_trail = xp.vstack([w0_close_i, w0_far_i])
i, x_lead_i, x_trail_i, v_lead_i, v_trail_i = carry
w0_lead_i = xp.hstack([x_lead_i, v_lead_i])
w0_trail_i = xp.hstack([x_trail_i, v_trail_i])
w0_lead_trail = xp.vstack([w0_lead_i, w0_trail_i])

minval, maxval = ts[i], ts[-1]
integ_ics = lambda ics: self.potential.integrate_orbit( # noqa: E731
ics, minval, maxval, None
)[0]
# vmap over leading and trailing arm
w_close, w_far = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail)
w_lead, w_trail = jax.vmap(integ_ics, in_axes=(0,))(w0_lead_trail)
carry_out = (
i + 1,
q_close[i + 1, :],
q_far[i + 1, :],
p_close[i + 1, :],
p_far[i + 1, :],
x_lead[i + 1, :],
x_trail[i + 1, :],
v_lead[i + 1, :],
v_trail[i + 1, :],
)
return carry_out, (w_close, w_far)
return carry_out, (w_lead, w_trail)

carry_init = (0, q_close[0, :], q_far[0, :], p_close[0, :], p_far[0, :])
particle_ids = xp.arange(len(q_close))
lead_arm, trail_arm = jax.lax.scan(scan_fun, carry_init, particle_ids)[1]
return lead_arm, trail_arm
carry_init = (0, x_lead[0, :], x_trail[0, :], v_lead[0, :], v_trail[0, :])
particle_ids = xp.arange(len(x_lead))
lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1]
return (lead_arm, trail_arm), prog_ws

@partial_jit(static_argnames=("seed_num",))
def _run_vmap(
self, ts: jt.Array, prog_w0: jt.Array, prog_mass: jt.Array, *, seed_num: int
) -> tuple[jt.Array, jt.Array]:
) -> tuple[tuple[jt.Array, jt.Array], jt.Array]:
"""
Generate stellar stream by vmapping over the release model/integration. Better for GPU usage.
"""
q_close_arr, q_far_arr, p_close_arr, p_far_arr = self._stream_ics(
ts, prog_w0, prog_mass, seed_num=seed_num
# Integrate the progenitor orbit
prog_ws = self.potential.integrate_orbit(prog_w0, xp.min(ts), xp.max(ts), ts)

# Generate stream initial conditions along the integrated progenitor orbit
x_lead, x_trail, v_lead, v_trail = self.df.sample(
self.potential, prog_ws, ts, prog_mass, seed_num=seed_num
)

# TODO: make this a separated method
@jax.jit # type: ignore[misc]
def single_particle_integrate(
i: int,
q_close_i: jt.Array,
q_far_i: jt.Array,
p_close_i: jt.Array,
p_far_i: jt.Array,
x_lead_i: jt.Array,
x_trail_i: jt.Array,
v_lead_i: jt.Array,
v_trail_i: jt.Array,
) -> tuple[jt.Array, jt.Array]:
w0_close_i = xp.hstack([q_close_i, p_close_i])
w0_far_i = xp.hstack([q_far_i, p_far_i])
w0_lead_i = xp.hstack([x_lead_i, v_lead_i])
w0_trail_i = xp.hstack([x_trail_i, v_trail_i])
t_i = ts[i]
t_f = ts[-1] + 0.01

w_close = self.integrate_orbit(w0_close_i, t_i, t_f, None)[0]
w_far = self.integrate_orbit(w0_far_i, t_i, t_f, None)[0]
w_lead = self.integrate_orbit(w0_lead_i, t_i, t_f, None)[0]
w_trail = self.integrate_orbit(w0_trail_i, t_i, t_f, None)[0]

return w_close, w_far
return w_lead, w_trail

particle_ids = xp.arange(len(q_close_arr))
particle_ids = xp.arange(len(x_lead))

integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0, 0, 0))
w_close, w_far = integrator(
particle_ids, q_close_arr, q_far_arr, p_close_arr, p_far_arr
)
return w_close, w_far
w_lead, w_trail = integrator(particle_ids, x_lead, x_trail, v_lead, v_trail)
return (w_lead, w_trail), prog_ws

@partial_jit(static_argnames=("seed_num", "vmapped"))
def run(
Expand Down

0 comments on commit 407ed56

Please sign in to comment.