Skip to content

Commit

Permalink
MockstreamGenerator returns mockstreams (GalacticDynamics#51)
Browse files Browse the repository at this point in the history
Not arrays, as it currently does.

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jan 14, 2024
1 parent 49c36e6 commit 81169ed
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 39 deletions.
8 changes: 7 additions & 1 deletion src/galax/dynamics/mockstream/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as xp

from galax.dynamics._core import AbstractPhaseSpacePositionBase
from galax.typing import BatchFloatScalar, BatchVec7
from galax.typing import BatchFloatScalar, BatchVec3, BatchVec7
from galax.utils import partial_jit
from galax.utils._shape import atleast_batched, batched_shape
from galax.utils.dataclasses import converter_float_array
Expand All @@ -23,6 +23,12 @@ class MockStream(AbstractPhaseSpacePositionBase):
- GR 4-vector stuff
"""

q: BatchVec3 = eqx.field(converter=converter_float_array)
"""Positions (x, y, z)."""

p: BatchVec3 = eqx.field(converter=converter_float_array)
r"""Conjugate momenta (v_x, v_y, v_z)."""

release_time: BatchFloatScalar = eqx.field(converter=converter_float_array)
"""Release time of the stream particles [Myr]."""

Expand Down
81 changes: 43 additions & 38 deletions src/galax/dynamics/mockstream/_mockstream_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import equinox as eqx
import jax
import jax.numpy as xp
from jaxtyping import Float

from galax.dynamics._orbit import Orbit
from galax.integrate._base import AbstractIntegrator
from galax.integrate._builtin import DiffraxIntegrator
from galax.potential._potential.base import AbstractPotentialBase
from galax.typing import (
BatchVec6,
FloatScalar,
IntScalar,
TimeVector,
Expand All @@ -24,6 +24,7 @@
from galax.utils import partial_jit
from galax.utils._collections import ImmutableDict

from ._core import MockStream
from ._df import AbstractStreamDF

Carry: TypeAlias = tuple[IntScalar, VecN, VecN]
Expand Down Expand Up @@ -56,23 +57,14 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc]

# ==========================================================================

@partial_jit(static_argnames=("seed_num",))
def _run_scan(
self, ts: TimeVector, prog_w0: Vec6, prog_mass: FloatScalar, *, seed_num: int
) -> tuple[tuple[Float[Vec6, "time"], Float[Vec6, "time"]], Orbit]:
@partial_jit()
def _run_scan( # TODO: output shape depends on the input shape
self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream
) -> tuple[BatchVec6, BatchVec6]:
"""Generate stellar stream by scanning over the release model/integration.
Better for CPU usage.
"""
# Integrate the progenitor orbit
prog_o = self.potential.integrate_orbit(
prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
mock0_lead, mock0_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
qp0_lead = mock0_lead.qp
qp0_trail = mock0_trail.qp

Expand All @@ -93,28 +85,20 @@ def integ_ics(ics: Vec6) -> VecN:

carry_init = (0, qp0_lead[0, :], qp0_trail[0, :])
particle_ids = xp.arange(len(qp0_lead))
lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1]
return (lead_arm, trail_arm), prog_o
lead_arm_qp, trail_arm_qp = jax.lax.scan(scan_fn, carry_init, particle_ids)[1]

return lead_arm_qp, trail_arm_qp

@partial_jit(static_argnames=("seed_num",))
def _run_vmap(
self, ts: TimeVector, prog_w0: Vec6, prog_mass: FloatScalar, *, seed_num: int
) -> tuple[tuple[Float[Vec6, "time"], Float[Vec6, "time"]], Orbit]:
@partial_jit()
def _run_vmap( # TODO: output shape depends on the input shape
self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream
) -> tuple[BatchVec6, BatchVec6]:
"""Generate stellar stream by vmapping over the release model/integration.
Better for GPU usage.
"""
# Integrate the progenitor orbit
prog_o = self.potential.integrate_orbit(
prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
mock_lead, mock_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)
qp0_lead = mock_lead.qp
qp0_trail = mock_trail.qp
qp0_lead = mock0_lead.qp
qp0_trail = mock0_trail.qp
t_f = ts[-1] + 0.01

# TODO: make this a separated method
Expand All @@ -133,8 +117,8 @@ def single_particle_integrate(

particle_ids = xp.arange(len(qp0_lead))
integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0))
qp_lead, qp_trail = integrator(particle_ids, qp0_lead, qp0_trail)
return (qp_lead, qp_trail), prog_o
lead_arm_qp, trail_arm_qp = integrator(particle_ids, qp0_lead, qp0_trail)
return lead_arm_qp, trail_arm_qp

@partial_jit(static_argnames=("seed_num", "vmapped"))
def run(
Expand All @@ -145,10 +129,31 @@ def run(
*,
seed_num: int,
vmapped: bool = False,
) -> tuple[tuple[Float[Vec6, "time"], Float[Vec6, "time"]], Orbit]:
# TODO: figure out better return type: MockStream?
) -> tuple[tuple[MockStream, MockStream], Orbit]:
# Integrate the progenitor orbit
prog_o = self.potential.integrate_orbit(
prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator
)

# Generate stream initial conditions along the integrated progenitor orbit
mock0_lead, mock0_trail = self.df.sample(
self.potential, prog_o, prog_mass, seed_num=seed_num
)

if vmapped:
return self._run_vmap(ts, prog_w0, prog_mass, seed_num=seed_num)
return self._run_scan(ts, prog_w0, prog_mass, seed_num=seed_num)
lead_arm_qp, trail_arm_qp = self._run_vmap(ts, prog_w0, prog_mass)
else:
lead_arm_qp, trail_arm_qp = self._run_scan(ts, mock0_lead, mock0_trail)

lead_arm = MockStream(
q=lead_arm_qp[:, 0:3],
p=lead_arm_qp[:, 3:6],
release_time=mock0_lead.release_time,
)
trail_arm = MockStream(
q=trail_arm_qp[:, 0:3],
p=trail_arm_qp[:, 3:6],
release_time=mock0_trail.release_time,
)

# ==========================================================================
return (lead_arm, trail_arm), prog_o

0 comments on commit 81169ed

Please sign in to comment.