From 81169ed1cbb246ca45b8aae3bf6becb029d3f8a4 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sun, 14 Jan 2024 17:53:03 -0500 Subject: [PATCH] MockstreamGenerator returns mockstreams (#51) Not arrays, as it currently does. Signed-off-by: nstarman --- src/galax/dynamics/mockstream/_core.py | 8 +- .../mockstream/_mockstream_generator.py | 81 ++++++++++--------- 2 files changed, 50 insertions(+), 39 deletions(-) diff --git a/src/galax/dynamics/mockstream/_core.py b/src/galax/dynamics/mockstream/_core.py index 0238f413..ff5c049f 100644 --- a/src/galax/dynamics/mockstream/_core.py +++ b/src/galax/dynamics/mockstream/_core.py @@ -6,7 +6,7 @@ import jax.numpy as xp from galax.dynamics._core import AbstractPhaseSpacePositionBase -from galax.typing import BatchFloatScalar, BatchVec7 +from galax.typing import BatchFloatScalar, BatchVec3, BatchVec7 from galax.utils import partial_jit from galax.utils._shape import atleast_batched, batched_shape from galax.utils.dataclasses import converter_float_array @@ -23,6 +23,12 @@ class MockStream(AbstractPhaseSpacePositionBase): - GR 4-vector stuff """ + q: BatchVec3 = eqx.field(converter=converter_float_array) + """Positions (x, y, z).""" + + p: BatchVec3 = eqx.field(converter=converter_float_array) + r"""Conjugate momenta (v_x, v_y, v_z).""" + release_time: BatchFloatScalar = eqx.field(converter=converter_float_array) """Release time of the stream particles [Myr].""" diff --git a/src/galax/dynamics/mockstream/_mockstream_generator.py b/src/galax/dynamics/mockstream/_mockstream_generator.py index edd73a45..3ac6a4fe 100644 --- a/src/galax/dynamics/mockstream/_mockstream_generator.py +++ b/src/galax/dynamics/mockstream/_mockstream_generator.py @@ -8,13 +8,13 @@ import equinox as eqx import jax import jax.numpy as xp -from jaxtyping import Float from galax.dynamics._orbit import Orbit from galax.integrate._base import AbstractIntegrator from galax.integrate._builtin import DiffraxIntegrator from galax.potential._potential.base import AbstractPotentialBase from galax.typing import ( + BatchVec6, FloatScalar, IntScalar, TimeVector, @@ -24,6 +24,7 @@ from galax.utils import partial_jit from galax.utils._collections import ImmutableDict +from ._core import MockStream from ._df import AbstractStreamDF Carry: TypeAlias = tuple[IntScalar, VecN, VecN] @@ -56,23 +57,14 @@ class MockStreamGenerator(eqx.Module): # type: ignore[misc] # ========================================================================== - @partial_jit(static_argnames=("seed_num",)) - def _run_scan( - self, ts: TimeVector, prog_w0: Vec6, prog_mass: FloatScalar, *, seed_num: int - ) -> tuple[tuple[Float[Vec6, "time"], Float[Vec6, "time"]], Orbit]: + @partial_jit() + def _run_scan( # TODO: output shape depends on the input shape + self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream + ) -> tuple[BatchVec6, BatchVec6]: """Generate stellar stream by scanning over the release model/integration. Better for CPU usage. """ - # Integrate the progenitor orbit - prog_o = self.potential.integrate_orbit( - prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator - ) - - # Generate stream initial conditions along the integrated progenitor orbit - mock0_lead, mock0_trail = self.df.sample( - self.potential, prog_o, prog_mass, seed_num=seed_num - ) qp0_lead = mock0_lead.qp qp0_trail = mock0_trail.qp @@ -93,28 +85,20 @@ def integ_ics(ics: Vec6) -> VecN: carry_init = (0, qp0_lead[0, :], qp0_trail[0, :]) particle_ids = xp.arange(len(qp0_lead)) - lead_arm, trail_arm = jax.lax.scan(scan_fn, carry_init, particle_ids)[1] - return (lead_arm, trail_arm), prog_o + lead_arm_qp, trail_arm_qp = jax.lax.scan(scan_fn, carry_init, particle_ids)[1] + + return lead_arm_qp, trail_arm_qp - @partial_jit(static_argnames=("seed_num",)) - def _run_vmap( - self, ts: TimeVector, prog_w0: Vec6, prog_mass: FloatScalar, *, seed_num: int - ) -> tuple[tuple[Float[Vec6, "time"], Float[Vec6, "time"]], Orbit]: + @partial_jit() + def _run_vmap( # TODO: output shape depends on the input shape + self, ts: TimeVector, mock0_lead: MockStream, mock0_trail: MockStream + ) -> tuple[BatchVec6, BatchVec6]: """Generate stellar stream by vmapping over the release model/integration. Better for GPU usage. """ - # Integrate the progenitor orbit - prog_o = self.potential.integrate_orbit( - prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator - ) - - # Generate stream initial conditions along the integrated progenitor orbit - mock_lead, mock_trail = self.df.sample( - self.potential, prog_o, prog_mass, seed_num=seed_num - ) - qp0_lead = mock_lead.qp - qp0_trail = mock_trail.qp + qp0_lead = mock0_lead.qp + qp0_trail = mock0_trail.qp t_f = ts[-1] + 0.01 # TODO: make this a separated method @@ -133,8 +117,8 @@ def single_particle_integrate( particle_ids = xp.arange(len(qp0_lead)) integrator = jax.vmap(single_particle_integrate, in_axes=(0, 0, 0)) - qp_lead, qp_trail = integrator(particle_ids, qp0_lead, qp0_trail) - return (qp_lead, qp_trail), prog_o + lead_arm_qp, trail_arm_qp = integrator(particle_ids, qp0_lead, qp0_trail) + return lead_arm_qp, trail_arm_qp @partial_jit(static_argnames=("seed_num", "vmapped")) def run( @@ -145,10 +129,31 @@ def run( *, seed_num: int, vmapped: bool = False, - ) -> tuple[tuple[Float[Vec6, "time"], Float[Vec6, "time"]], Orbit]: - # TODO: figure out better return type: MockStream? + ) -> tuple[tuple[MockStream, MockStream], Orbit]: + # Integrate the progenitor orbit + prog_o = self.potential.integrate_orbit( + prog_w0, xp.min(ts), xp.max(ts), ts, integrator=self.progenitor_integrator + ) + + # Generate stream initial conditions along the integrated progenitor orbit + mock0_lead, mock0_trail = self.df.sample( + self.potential, prog_o, prog_mass, seed_num=seed_num + ) + if vmapped: - return self._run_vmap(ts, prog_w0, prog_mass, seed_num=seed_num) - return self._run_scan(ts, prog_w0, prog_mass, seed_num=seed_num) + lead_arm_qp, trail_arm_qp = self._run_vmap(ts, prog_w0, prog_mass) + else: + lead_arm_qp, trail_arm_qp = self._run_scan(ts, mock0_lead, mock0_trail) + + lead_arm = MockStream( + q=lead_arm_qp[:, 0:3], + p=lead_arm_qp[:, 3:6], + release_time=mock0_lead.release_time, + ) + trail_arm = MockStream( + q=trail_arm_qp[:, 0:3], + p=trail_arm_qp[:, 3:6], + release_time=mock0_trail.release_time, + ) - # ========================================================================== + return (lead_arm, trail_arm), prog_o